[Mlir-commits] [mlir] [mlir][tosa] Add support for fp32 accumulators in fp8 convolution operations (PR #175517)
Luke Hutton
llvmlistbot at llvm.org
Mon Jan 12 05:51:28 PST 2026
https://github.com/lhutton1 updated https://github.com/llvm/llvm-project/pull/175517
From 2baf2d40377224ff1035f793dd7144dae60fa5c3 Mon Sep 17 00:00:00 2001
From: Luke Hutton <luke.hutton at arm.com>
Date: Mon, 12 Jan 2026 10:35:03 +0000
Subject: [PATCH 1/2] [mlir][tosa] Add support for fp32 accumulators in fp8
convolution operations
As per the spec change: https://github.com/arm/tosa-specification/pull/16
Change-Id: Ie05dc64173f7d244cf265fabbce94684007b4af0
---
.../Dialect/Tosa/IR/TosaComplianceData.h.inc | 32 +++++++++++----
mlir/lib/Dialect/Tosa/IR/TosaOps.cpp | 22 ++++++----
.../TosaToLinalg/tosa-to-linalg-named.mlir | 40 +++++++++++++++++++
mlir/test/Dialect/Tosa/invalid.mlir | 22 +++++-----
mlir/test/Dialect/Tosa/ops.mlir | 36 +++++++++++++++++
.../tosa-validation-version-1p0-invalid.mlir | 40 +++++++++++++++++++
.../tosa-validation-version-1p1-valid.mlir | 40 +++++++++++++++++++
7 files changed, 206 insertions(+), 26 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc b/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc
index fd55cd82b8663..1ee7c38a5a03f 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc
@@ -505,10 +505,14 @@ extensionComplianceMap = {
SpecificationVersion::V_1_0}}},
{{Extension::fp8e4m3},
{{{fp8e4m3T, fp8e4m3T, fp16T, fp8e4m3T, fp8e4m3T, fp16T, fp16T},
- SpecificationVersion::V_1_0}}},
+ SpecificationVersion::V_1_0},
+ {{fp8e4m3T, fp8e4m3T, fp16T, fp8e4m3T, fp8e4m3T, fp32T, fp16T},
+ SpecificationVersion::V_1_1_DRAFT}}},
{{Extension::fp8e5m2},
{{{fp8e5m2T, fp8e5m2T, fp16T, fp8e5m2T, fp8e5m2T, fp16T, fp16T},
- SpecificationVersion::V_1_0}}},
+ SpecificationVersion::V_1_0},
+ {{fp8e5m2T, fp8e5m2T, fp16T, fp8e5m2T, fp8e5m2T, fp32T, fp16T},
+ SpecificationVersion::V_1_1_DRAFT}}},
{{Extension::bf16},
{{{bf16T, bf16T, bf16T, bf16T, bf16T, fp32T, bf16T},
SpecificationVersion::V_1_0}}}}},
@@ -520,10 +524,14 @@ extensionComplianceMap = {
SpecificationVersion::V_1_0}}},
{{Extension::fp8e4m3},
{{{fp8e4m3T, fp8e4m3T, fp16T, fp8e4m3T, fp8e4m3T, fp16T, fp16T},
- SpecificationVersion::V_1_0}}},
+ SpecificationVersion::V_1_0},
+ {{fp8e4m3T, fp8e4m3T, fp16T, fp8e4m3T, fp8e4m3T, fp32T, fp16T},
+ SpecificationVersion::V_1_1_DRAFT}}},
{{Extension::fp8e5m2},
{{{fp8e5m2T, fp8e5m2T, fp16T, fp8e5m2T, fp8e5m2T, fp16T, fp16T},
- SpecificationVersion::V_1_0}}},
+ SpecificationVersion::V_1_0},
+ {{fp8e5m2T, fp8e5m2T, fp16T, fp8e5m2T, fp8e5m2T, fp32T, fp16T},
+ SpecificationVersion::V_1_1_DRAFT}}},
{{Extension::bf16},
{{{bf16T, bf16T, bf16T, bf16T, bf16T, fp32T, bf16T},
SpecificationVersion::V_1_0}}}}},
@@ -535,10 +543,14 @@ extensionComplianceMap = {
SpecificationVersion::V_1_0}}},
{{Extension::fp8e4m3},
{{{fp8e4m3T, fp8e4m3T, fp16T, fp8e4m3T, fp8e4m3T, fp16T, fp16T},
- SpecificationVersion::V_1_0}}},
+ SpecificationVersion::V_1_0},
+ {{fp8e4m3T, fp8e4m3T, fp16T, fp8e4m3T, fp8e4m3T, fp32T, fp16T},
+ SpecificationVersion::V_1_1_DRAFT}}},
{{Extension::fp8e5m2},
{{{fp8e5m2T, fp8e5m2T, fp16T, fp8e5m2T, fp8e5m2T, fp16T, fp16T},
- SpecificationVersion::V_1_0}}},
+ SpecificationVersion::V_1_0},
+ {{fp8e5m2T, fp8e5m2T, fp16T, fp8e5m2T, fp8e5m2T, fp32T, fp16T},
+ SpecificationVersion::V_1_1_DRAFT}}},
{{Extension::bf16},
{{{bf16T, bf16T, bf16T, bf16T, bf16T, fp32T, bf16T},
SpecificationVersion::V_1_0}}}}},
@@ -602,10 +614,14 @@ extensionComplianceMap = {
SpecificationVersion::V_1_0}}},
{{Extension::fp8e4m3},
{{{fp8e4m3T, fp8e4m3T, fp16T, fp8e4m3T, fp8e4m3T, fp16T, fp16T},
- SpecificationVersion::V_1_0}}},
+ SpecificationVersion::V_1_0},
+ {{fp8e4m3T, fp8e4m3T, fp16T, fp8e4m3T, fp8e4m3T, fp32T, fp16T},
+ SpecificationVersion::V_1_1_DRAFT}}},
{{Extension::fp8e5m2},
{{{fp8e5m2T, fp8e5m2T, fp16T, fp8e5m2T, fp8e5m2T, fp16T, fp16T},
- SpecificationVersion::V_1_0}}},
+ SpecificationVersion::V_1_0},
+ {{fp8e5m2T, fp8e5m2T, fp16T, fp8e5m2T, fp8e5m2T, fp32T, fp16T},
+ SpecificationVersion::V_1_1_DRAFT}}},
{{Extension::bf16},
{{{bf16T, bf16T, bf16T, bf16T, bf16T, fp32T, bf16T},
SpecificationVersion::V_1_0}}}}},
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 5656f3de698c5..6e934526d9035 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -731,22 +731,30 @@ static LogicalResult verifyConvOpModes(T op) {
auto accType = op.getAccType();
if (inputEType.isInteger(8) && !accType.isInteger(32))
- return op.emitOpError("accumulator type for i8 tensor is not i32");
+ return op.emitOpError("accumulator type for i8 tensor is not i32, got ")
+ << accType;
if (inputEType.isInteger(16) && !accType.isInteger(48))
- return op.emitOpError("accumulator type for i16 tensor is not i48");
+ return op.emitOpError("accumulator type for i16 tensor is not i48, got ")
+ << accType;
- if (isa<Float8E5M2Type, Float8E4M3Type>(inputEType) && !accType.isF16())
- return op.emitOpError("accumulator type for f8 tensor is not f16");
+ if (isa<Float8E5M2Type, Float8E4M3Type>(inputEType) &&
+ !(accType.isF16() || accType.isF32()))
+ return op.emitOpError("accumulator type for f8 tensor is not f16/f32, got ")
+ << accType;
if (inputEType.isF16() && !(accType.isF16() || accType.isF32()))
- return op.emitOpError("accumulator type for f16 tensor is not f16/f32");
+ return op.emitOpError(
+ "accumulator type for f16 tensor is not f16/f32, got ")
+ << accType;
if (inputEType.isBF16() && !accType.isF32())
- return op.emitOpError("accumulator type for bf16 tensor is not f32");
+ return op.emitOpError("accumulator type for bf16 tensor is not f32, got ")
+ << accType;
if (inputEType.isF32() && !accType.isF32())
- return op.emitOpError("accumulator type for f32 tensor is not f32");
+ return op.emitOpError("accumulator type for f32 tensor is not f32, got ")
+ << accType;
auto resultEType =
llvm::cast<ShapedType>(op.getResult().getType()).getElementType();
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
index 9ea224a2854e5..c648fb5bdf8e3 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
@@ -672,6 +672,20 @@ func.func @conv2d_f16_f32_acc(%input: tensor<1x49x42x27xf16>, %weights: tensor<2
// -----
+// CHECK-LABEL: @conv2d_f8_f32_acc
+func.func @conv2d_f8_f32_acc(%input: tensor<1x49x42x27xf8E5M2>, %weights: tensor<28x3x3x27xf8E5M2>, %bias: tensor<28xf16>) -> () {
+ %input_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf8E5M2>}> : () -> tensor<1xf8E5M2>
+ %weight_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf8E5M2>}> : () -> tensor<1xf8E5M2>
+ // CHECK: linalg.generic {{{.*}}} ins(%{{.*}} : tensor<28xf16>) outs(%{{.*}} : tensor<1x45x40x28xf32>)
+ // CHECK: arith.extf %{{.*}} : f16 to f32
+ // CHECK: %[[CONV:.*]] = linalg.conv_2d_nhwc_fhwc {{{.*}}} ins(%{{.*}}, %{{.*}} : tensor<1x49x42x27xf8E5M2>, tensor<28x3x3x27xf8E5M2>) outs(%{{.*}} : tensor<1x45x40x28xf32>) -> tensor<1x45x40x28xf32>
+ // CHECK: tosa.cast %[[CONV]] : (tensor<1x45x40x28xf32>) -> tensor<1x45x40x28xf16>
+ %0 = tosa.conv2d %input, %weights, %bias, %input_zp, %weight_zp {acc_type = f32, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, dilation = array<i64: 2, 1>} : (tensor<1x49x42x27xf8E5M2>, tensor<28x3x3x27xf8E5M2>, tensor<28xf16>, tensor<1xf8E5M2>, tensor<1xf8E5M2>) -> tensor<1x45x40x28xf16>
+ return
+}
+
+// -----
+
// CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d3)>
// CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
@@ -874,6 +888,18 @@ func.func @depthwise_conv2d_f16_f32_acc(%arg0 : tensor<1x7x5x3xf16>, %arg1 : ten
// -----
+// CHECK-LABEL: @depthwise_conv2d_f8_f32_acc
+func.func @depthwise_conv2d_f8_f32_acc(%arg0 : tensor<1x7x5x3xf8E5M2>, %arg1 : tensor<3x1x3x11xf8E5M2>, %arg2 : tensor<33xf16>) -> () {
+ %input_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf8E5M2>}> : () -> tensor<1xf8E5M2>
+ %weight_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf8E5M2>}> : () -> tensor<1xf8E5M2>
+ // CHECK: %[[CONV:.*]] = linalg.depthwise_conv_2d_nhwc_hwcm {{{.*}}} ins(%{{.*}}, %{{.*}} : tensor<1x7x5x3xf8E5M2>, tensor<3x1x3x11xf8E5M2>) outs(%{{.*}} : tensor<1x5x5x3x11xf32>) -> tensor<1x5x5x3x11xf32>
+ // CHECK: tosa.cast %[[CONV]] : (tensor<1x5x5x3x11xf32>) -> tensor<1x5x5x3x11xf16>
+ %2 = tosa.depthwise_conv2d %arg0, %arg1, %arg2, %input_zp, %weight_zp {acc_type = f32, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, dilation = array<i64: 1, 1> } : (tensor<1x7x5x3xf8E5M2>, tensor<3x1x3x11xf8E5M2>, tensor<33xf16>, tensor<1xf8E5M2>, tensor<1xf8E5M2>) -> tensor<1x5x5x33xf16>
+ return
+}
+
+// -----
+
// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d4)>
// CHECK: #[[$MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
@@ -958,6 +984,20 @@ func.func @conv3d_f16_f32_acc(%input: tensor<1x49x48x47x27xf16>, %weights: tenso
// -----
+// CHECK-LABEL: @conv3d_f8_f32_acc
+func.func @conv3d_f8_f32_acc(%input: tensor<1x49x48x47x27xf8E5M2>, %weights: tensor<43x3x4x5x27xf8E5M2>, %bias: tensor<43xf16>) -> () {
+ %input_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf8E5M2>}> : () -> tensor<1xf8E5M2>
+ %weight_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf8E5M2>}> : () -> tensor<1xf8E5M2>
+ // CHECK: linalg.generic {{{.*}}} ins(%{{.*}} : tensor<43xf16>) outs(%{{.*}} : tensor<1x47x45x43x43xf32>)
+ // CHECK: arith.extf %{{.*}} : f16 to f32
+ // CHECK: %[[CONV:.*]] = linalg.conv_3d_ndhwc_dhwcf {{{.*}}} ins(%{{.*}}, %{{.*}} : tensor<1x49x48x47x27xf8E5M2>, tensor<3x4x5x27x43xf8E5M2>) outs(%{{.*}} : tensor<1x47x45x43x43xf32>) -> tensor<1x47x45x43x43xf32>
+ // CHECK: tosa.cast %[[CONV]] : (tensor<1x47x45x43x43xf32>) -> tensor<1x47x45x43x43xf16>
+ %0 = tosa.conv3d %input, %weights, %bias, %input_zp, %weight_zp {acc_type = f32, pad = array<i64: 0, 0, 0, 0, 0, 0>, stride = array<i64: 1, 1, 1>, dilation = array<i64: 1, 1, 1>} : (tensor<1x49x48x47x27xf8E5M2>, tensor<43x3x4x5x27xf8E5M2>, tensor<43xf16>, tensor<1xf8E5M2>, tensor<1xf8E5M2>) -> tensor<1x47x45x43x43xf16>
+ return
+}
+
+// -----
+
// CHECK-LABEL: @test_transpose
// CHECK-SAME: (%[[ARG0:.+]]: tensor<1x2x3xi32>)
func.func @test_transpose(%arg0: tensor<1x2x3xi32>) -> () {
diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir
index e8206c24f1507..fb79070b3a8a4 100644
--- a/mlir/test/Dialect/Tosa/invalid.mlir
+++ b/mlir/test/Dialect/Tosa/invalid.mlir
@@ -76,7 +76,7 @@ func.func @test_conv2d_weight_zp(%arg0: tensor<1x29x29x4xf16>, %arg1: tensor<16x
func.func @test_conv2d_acc_type(%arg0: tensor<1x29x29x4xi8>, %arg1: tensor<16x3x3x4xi8>, %arg2: tensor<16xi8>) -> tensor<1x27x27x16xi8> {
%zp = "tosa.const"() {values = dense<0> : tensor<1xi8>} : () -> tensor<1xi8>
- // expected-error at +1 {{'tosa.conv2d' op accumulator type for i8 tensor is not i32}}
+ // expected-error at +1 {{'tosa.conv2d' op accumulator type for i8 tensor is not i32, got 'f16'}}
%0 = tosa.conv2d %arg0, %arg1, %arg2, %zp, %zp {acc_type = f16, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>}
: (tensor<1x29x29x4xi8>, tensor<16x3x3x4xi8>, tensor<16xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x27x27x16xi8>
return %0 : tensor<1x27x27x16xi8>
@@ -87,7 +87,7 @@ func.func @test_conv2d_acc_type(%arg0: tensor<1x29x29x4xi8>, %arg1: tensor<16x3x
func.func @test_conv2d_acc_type(%arg0: tensor<1x29x29x4xi16>, %arg1: tensor<16x3x3x4xi8>, %arg2: tensor<16xi16>) -> tensor<1x27x27x16xi16> {
%input_zp = "tosa.const"() {values = dense<0> : tensor<1xi16>} : () -> tensor<1xi16>
%weight_zp = "tosa.const"() {values = dense<0> : tensor<1xi8>} : () -> tensor<1xi8>
- // expected-error at +1 {{'tosa.conv2d' op accumulator type for i16 tensor is not i48}}
+ // expected-error at +1 {{'tosa.conv2d' op accumulator type for i16 tensor is not i48, got 'f16'}}
%0 = tosa.conv2d %arg0, %arg1, %arg2, %input_zp, %weight_zp {acc_type = f16, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>}
: (tensor<1x29x29x4xi16>, tensor<16x3x3x4xi8>, tensor<16xi16>, tensor<1xi16>, tensor<1xi8>) -> tensor<1x27x27x16xi16>
return %0 : tensor<1x27x27x16xi16>
@@ -97,7 +97,7 @@ func.func @test_conv2d_acc_type(%arg0: tensor<1x29x29x4xi16>, %arg1: tensor<16x3
func.func @test_conv2d_acc_type(%arg0: tensor<1x29x29x4xf8E5M2>, %arg1: tensor<16x3x3x4xf8E5M2>, %arg2: tensor<16xf16>) -> tensor<1x27x27x16xf16> {
%zp = "tosa.const"() {values = dense<0.0> : tensor<1xf8E5M2>} : () -> tensor<1xf8E5M2>
- // expected-error at +1 {{'tosa.conv2d' op accumulator type for f8 tensor is not f16}}
+ // expected-error at +1 {{'tosa.conv2d' op accumulator type for f8 tensor is not f16/f32, got 'i32'}}
%0 = tosa.conv2d %arg0, %arg1, %arg2, %zp, %zp {acc_type = i32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>}
: (tensor<1x29x29x4xf8E5M2>, tensor<16x3x3x4xf8E5M2>, tensor<16xf16>, tensor<1xf8E5M2>, tensor<1xf8E5M2>) -> tensor<1x27x27x16xf16>
return %0 : tensor<1x27x27x16xf16>
@@ -107,7 +107,7 @@ func.func @test_conv2d_acc_type(%arg0: tensor<1x29x29x4xf8E5M2>, %arg1: tensor<1
func.func @test_conv2d_acc_type(%arg0: tensor<1x29x29x4xf8E4M3>, %arg1: tensor<16x3x3x4xf8E4M3>, %arg2: tensor<16xf16>) -> tensor<1x27x27x16xf16> {
%zp = "tosa.const"() {values = dense<0.0> : tensor<1xf8E4M3>} : () -> tensor<1xf8E4M3>
- // expected-error at +1 {{'tosa.conv2d' op accumulator type for f8 tensor is not f16}}
+ // expected-error at +1 {{'tosa.conv2d' op accumulator type for f8 tensor is not f16/f32, got 'i32'}}
%0 = tosa.conv2d %arg0, %arg1, %arg2, %zp, %zp {acc_type = i32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>}
: (tensor<1x29x29x4xf8E4M3>, tensor<16x3x3x4xf8E4M3>, tensor<16xf16>, tensor<1xf8E4M3>, tensor<1xf8E4M3>) -> tensor<1x27x27x16xf16>
return %0 : tensor<1x27x27x16xf16>
@@ -117,7 +117,7 @@ func.func @test_conv2d_acc_type(%arg0: tensor<1x29x29x4xf8E4M3>, %arg1: tensor<1
func.func @test_conv2d_acc_type(%arg0: tensor<1x29x29x4xf16>, %arg1: tensor<16x3x3x4xf16>, %arg2: tensor<16xf16>) -> tensor<1x27x27x16xf16> {
%zp = "tosa.const"() {values = dense<0.0> : tensor<1xf16>} : () -> tensor<1xf16>
- // expected-error at +1 {{'tosa.conv2d' op accumulator type for f16 tensor is not f16/f32}}
+ // expected-error at +1 {{'tosa.conv2d' op accumulator type for f16 tensor is not f16/f32, got 'i32'}}
%0 = tosa.conv2d %arg0, %arg1, %arg2, %zp, %zp {acc_type = i32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>}
: (tensor<1x29x29x4xf16>, tensor<16x3x3x4xf16>, tensor<16xf16>, tensor<1xf16>, tensor<1xf16>) -> tensor<1x27x27x16xf16>
return %0 : tensor<1x27x27x16xf16>
@@ -127,7 +127,7 @@ func.func @test_conv2d_acc_type(%arg0: tensor<1x29x29x4xf16>, %arg1: tensor<16x3
func.func @test_conv2d_acc_type(%arg0: tensor<1x29x29x4xbf16>, %arg1: tensor<16x3x3x4xbf16>, %arg2: tensor<16xbf16>) -> tensor<1x27x27x16xbf16> {
%zp = "tosa.const"() {values = dense<0.0> : tensor<1xbf16>} : () -> tensor<1xbf16>
- // expected-error at +1 {{'tosa.conv2d' op accumulator type for bf16 tensor is not f32}}
+ // expected-error at +1 {{'tosa.conv2d' op accumulator type for bf16 tensor is not f32, got 'i32'}}
%0 = tosa.conv2d %arg0, %arg1, %arg2, %zp, %zp {acc_type = i32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>}
: (tensor<1x29x29x4xbf16>, tensor<16x3x3x4xbf16>, tensor<16xbf16>, tensor<1xbf16>, tensor<1xbf16>) -> tensor<1x27x27x16xbf16>
return %0 : tensor<1x27x27x16xbf16>
@@ -137,7 +137,7 @@ func.func @test_conv2d_acc_type(%arg0: tensor<1x29x29x4xbf16>, %arg1: tensor<16x
func.func @test_conv2d_acc_type(%arg0: tensor<1x29x29x4xf32>, %arg1: tensor<16x3x3x4xf32>, %arg2: tensor<16xf32>) -> tensor<1x27x27x16xf32> {
%zp = "tosa.const"() {values = dense<0.0> : tensor<1xf32>} : () -> tensor<1xf32>
- // expected-error at +1 {{'tosa.conv2d' op accumulator type for f32 tensor is not f32}}
+ // expected-error at +1 {{'tosa.conv2d' op accumulator type for f32 tensor is not f32, got 'i32'}}
%0 = tosa.conv2d %arg0, %arg1, %arg2, %zp, %zp {acc_type = i32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>}
: (tensor<1x29x29x4xf32>, tensor<16x3x3x4xf32>, tensor<16xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x27x27x16xf32>
return %0 : tensor<1x27x27x16xf32>
@@ -147,7 +147,7 @@ func.func @test_conv2d_acc_type(%arg0: tensor<1x29x29x4xf32>, %arg1: tensor<16x3
func.func @test_conv3d_acc_type(%arg0: tensor<1x4x8x21x17xi8>, %arg1: tensor<34x1x1x1x17xi8>, %arg2: tensor<34xi8>) -> tensor<1x4x8x21x34xi8> {
%zp = "tosa.const"() {values = dense<0> : tensor<1xi8>} : () -> tensor<1xi8>
- // expected-error at +1 {{'tosa.conv3d' op accumulator type for i8 tensor is not i32}}
+ // expected-error at +1 {{'tosa.conv3d' op accumulator type for i8 tensor is not i32, got 'f16'}}
%0 = tosa.conv3d %arg0, %arg1, %arg2, %zp, %zp {acc_type = f16, dilation = array<i64: 1, 1, 1>, pad = array<i64: 0, 0, 0, 0, 0, 0>, stride = array<i64: 1, 1, 1>}
: (tensor<1x4x8x21x17xi8>, tensor<34x1x1x1x17xi8>, tensor<34xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x4x8x21x34xi8>
return %0 : tensor<1x4x8x21x34xi8>
@@ -157,7 +157,7 @@ func.func @test_conv3d_acc_type(%arg0: tensor<1x4x8x21x17xi8>, %arg1: tensor<34x
func.func @test_depthwise_conv2d_acc_type(%arg0: tensor<1x4x4x4xi8>, %arg1: tensor<1x1x4x2xi8>, %arg2: tensor<8xi8>) -> tensor<1x4x4x8xi8> {
%zp = "tosa.const"() {values = dense<0> : tensor<1xi8>} : () -> tensor<1xi8>
- // expected-error at +1 {{'tosa.depthwise_conv2d' op accumulator type for i8 tensor is not i32}}
+ // expected-error at +1 {{'tosa.depthwise_conv2d' op accumulator type for i8 tensor is not i32, got 'f16'}}
%0 = tosa.depthwise_conv2d %arg0, %arg1, %arg2, %zp, %zp {acc_type = f16, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<1x4x4x4xi8>, tensor<1x1x4x2xi8>, tensor<8xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x4x4x8xi8>
return %0 : tensor<1x4x4x8xi8>
}
@@ -166,7 +166,7 @@ func.func @test_depthwise_conv2d_acc_type(%arg0: tensor<1x4x4x4xi8>, %arg1: tens
func.func @test_transpose_conv2d(%arg0: tensor<1x32x32x8xi8>, %arg1: tensor<16x1x1x8xi8>, %arg2: tensor<16xi8>) -> tensor<1x32x32x16xi8> {
%zp = "tosa.const"() {values = dense<0> : tensor<1xi8>} : () -> tensor<1xi8>
- // expected-error at +1 {{'tosa.transpose_conv2d' op accumulator type for i8 tensor is not i32}}
+ // expected-error at +1 {{'tosa.transpose_conv2d' op accumulator type for i8 tensor is not i32, got 'f16'}}
%0 = tosa.transpose_conv2d %arg0, %arg1, %arg2, %zp, %zp {acc_type = f16, out_pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<1x32x32x8xi8>, tensor<16x1x1x8xi8>, tensor<16xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x32x32x16xi8>
return %0 : tensor<1x32x32x16xi8>
}
@@ -247,7 +247,7 @@ func.func @test_transpose_conv2d_invalid_bias(%arg0: tensor<1x32x32x8xf32>, %arg
// CHECK-LABEL: conv2d_quant_any_acc
func.func @test_conv2d_quant_any_acc(%arg0: tensor<1x4x4x4x!quant.any<i8<-8:7>>>, %arg1: tensor<8x1x1x4x!quant.any<i8<-8:7>>>, %arg2: tensor<8x!quant.any<i8<-8:7>>>) -> tensor<1x4x4x8x!quant.any<i8<-8:7>>> {
%zp = "tosa.const" () { values = dense<0> : tensor<1xi8> } : () -> tensor<1xi8>
- // expected-error at +1 {{'tosa.conv2d' op accumulator type for i8 tensor is not i32}}
+ // expected-error at +1 {{'tosa.conv2d' op accumulator type for i8 tensor is not i32, got 'f32'}}
%0 = tosa.conv2d %arg0, %arg1, %arg2, %zp, %zp {acc_type = f32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, local_bound = true} : (tensor<1x4x4x4x!quant.any<i8<-8:7>>>, tensor<8x1x1x4x!quant.any<i8<-8:7>>>, tensor<8x!quant.any<i8<-8:7>>>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x4x4x8x!quant.any<i8<-8:7>>>
return %0 : tensor<1x4x4x8x!quant.any<i8<-8:7>>>
}
diff --git a/mlir/test/Dialect/Tosa/ops.mlir b/mlir/test/Dialect/Tosa/ops.mlir
index 276eac4d6166d..a6aaebb8b6a10 100644
--- a/mlir/test/Dialect/Tosa/ops.mlir
+++ b/mlir/test/Dialect/Tosa/ops.mlir
@@ -981,6 +981,15 @@ func.func @test_conv2d_f8E5M2(%arg0: tensor<1x4x4x4xf8E5M2>, %arg1: tensor<8x1x1
return %0 : tensor<1x4x4x8xf16>
}
+// -----
+// CHECK-LABEL: conv2d_f8E5M2_acc32
+func.func @test_conv2d_f8E5M2_acc32(%arg0: tensor<1x4x4x4xf8E5M2>, %arg1: tensor<8x1x1x4xf8E5M2>, %arg2: tensor<8xf16>) -> tensor<1x4x4x8xf16> {
+ %input_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf8E5M2>}> : () -> tensor<1xf8E5M2>
+ %weight_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf8E5M2>}> : () -> tensor<1xf8E5M2>
+ %0 = tosa.conv2d %arg0, %arg1, %arg2, %input_zp, %weight_zp {acc_type = f32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<1x4x4x4xf8E5M2>, tensor<8x1x1x4xf8E5M2>, tensor<8xf16>, tensor<1xf8E5M2>, tensor<1xf8E5M2>) -> tensor<1x4x4x8xf16>
+ return %0 : tensor<1x4x4x8xf16>
+}
+
// -----
// CHECK-LABEL: conv3d_f8E5M2
func.func @test_conv3d_f8E5M2(%arg0: tensor<1x4x8x21x17xf8E5M2>, %arg1: tensor<34x1x1x1x17xf8E5M2>, %arg2: tensor<34xf16>, %arg3: tensor<1xf8E5M2>, %arg4: tensor<1xf8E5M2>) -> tensor<1x4x8x21x34xf16> {
@@ -988,6 +997,15 @@ func.func @test_conv3d_f8E5M2(%arg0: tensor<1x4x8x21x17xf8E5M2>, %arg1: tensor<3
return %0 : tensor<1x4x8x21x34xf16>
}
+// -----
+// CHECK-LABEL: conv3d_f8E5M2_acc32
+func.func @test_conv3d_f8E5M2_acc32(%arg0: tensor<1x4x8x21x17xf8E5M2>, %arg1: tensor<34x1x1x1x17xf8E5M2>, %arg2: tensor<34xf16>) -> tensor<1x4x8x21x34xf16> {
+ %input_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf8E5M2>}> : () -> tensor<1xf8E5M2>
+ %weight_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf8E5M2>}> : () -> tensor<1xf8E5M2>
+ %0 = tosa.conv3d %arg0, %arg1, %arg2, %input_zp, %weight_zp {acc_type = f32, dilation = array<i64: 1, 1, 1>, pad = array<i64: 0, 0, 0, 0, 0, 0>, stride = array<i64: 1, 1, 1>} : (tensor<1x4x8x21x17xf8E5M2>, tensor<34x1x1x1x17xf8E5M2>, tensor<34xf16>, tensor<1xf8E5M2>, tensor<1xf8E5M2>) -> tensor<1x4x8x21x34xf16>
+ return %0 : tensor<1x4x8x21x34xf16>
+}
+
// -----
// CHECK-LABEL: depthwise_conv2d_f8E5M2
func.func @test_depthwise_conv2d_f8E5M2(%arg0: tensor<1x4x4x4xf8E5M2>, %arg1: tensor<1x1x4x2xf8E5M2>, %arg2: tensor<8xf16>, %arg3: tensor<1xf8E5M2>, %arg4: tensor<1xf8E5M2>) -> tensor<1x4x4x8xf16> {
@@ -995,6 +1013,15 @@ func.func @test_depthwise_conv2d_f8E5M2(%arg0: tensor<1x4x4x4xf8E5M2>, %arg1: te
return %0 : tensor<1x4x4x8xf16>
}
+// -----
+// CHECK-LABEL: depthwise_conv2d_f8E5M2_acc32
+func.func @test_depthwise_conv2d_f8E5M2_acc32(%arg0: tensor<1x4x4x4xf8E5M2>, %arg1: tensor<1x1x4x2xf8E5M2>, %arg2: tensor<8xf16>) -> tensor<1x4x4x8xf16> {
+ %input_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf8E5M2>}> : () -> tensor<1xf8E5M2>
+ %weight_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf8E5M2>}> : () -> tensor<1xf8E5M2>
+ %0 = tosa.depthwise_conv2d %arg0, %arg1, %arg2, %input_zp, %weight_zp {acc_type = f32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<1x4x4x4xf8E5M2>, tensor<1x1x4x2xf8E5M2>, tensor<8xf16>, tensor<1xf8E5M2>, tensor<1xf8E5M2>) -> tensor<1x4x4x8xf16>
+ return %0 : tensor<1x4x4x8xf16>
+}
+
// -----
// CHECK-LABEL: test_matmul_f8E5M2
func.func @test_matmul_f8E5M2(%arg0: tensor<1x14x19xf8E5M2>, %arg1: tensor<1x19x28xf8E5M2>) -> tensor<1x14x28xf16> {
@@ -1027,6 +1054,15 @@ func.func @test_transpose_conv2d_f8E5M2(%arg0: tensor<1x32x32x8xf8E5M2>, %arg1:
return %0 : tensor<1x32x32x16xf16>
}
+// -----
+// CHECK-LABEL: transpose_conv2d_f8E5M2_acc32
+func.func @test_transpose_conv2d_f8E5M2_acc32(%arg0: tensor<1x32x32x8xf8E5M2>, %arg1: tensor<16x1x1x8xf8E5M2>, %arg2: tensor<16xf16>) -> tensor<1x32x32x16xf16> {
+ %input_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf8E5M2>}> : () -> tensor<1xf8E5M2>
+ %weight_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf8E5M2>}> : () -> tensor<1xf8E5M2>
+ %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2, %input_zp, %weight_zp {acc_type = f32, out_pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<1x32x32x8xf8E5M2>, tensor<16x1x1x8xf8E5M2>, tensor<16xf16>, tensor<1xf8E5M2>, tensor<1xf8E5M2>) -> tensor<1x32x32x16xf16>
+ return %0 : tensor<1x32x32x16xf16>
+}
+
// -----
// CHECK-LABEL: const_f8E5M2
func.func @test_const_f8E5M2(%arg0 : index) -> tensor<4xf8E5M2> {
diff --git a/mlir/test/Dialect/Tosa/tosa-validation-version-1p0-invalid.mlir b/mlir/test/Dialect/Tosa/tosa-validation-version-1p0-invalid.mlir
index fa40913834b5b..9ecede6782da7 100644
--- a/mlir/test/Dialect/Tosa/tosa-validation-version-1p0-invalid.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-validation-version-1p0-invalid.mlir
@@ -22,6 +22,46 @@ func.func @test_matmul_fp8_input_fp32_acc_type(%arg0: tensor<1x14x19xf8E4M3FN>,
// -----
+func.func @test_conv2d_fp8_acc32(%arg0: tensor<1x4x4x4xf8E5M2>, %arg1: tensor<8x1x1x4xf8E5M2>, %arg2: tensor<8xf16>) -> tensor<1x4x4x8xf16> {
+ %input_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf8E5M2>}> : () -> tensor<1xf8E5M2>
+ %weight_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf8E5M2>}> : () -> tensor<1xf8E5M2>
+ // expected-error at +1 {{'tosa.conv2d' op illegal: the target specification version (1.0) is not backwards compatible with the op compliance specification version (1.1)}}
+ %0 = tosa.conv2d %arg0, %arg1, %arg2, %input_zp, %weight_zp {acc_type = f32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<1x4x4x4xf8E5M2>, tensor<8x1x1x4xf8E5M2>, tensor<8xf16>, tensor<1xf8E5M2>, tensor<1xf8E5M2>) -> tensor<1x4x4x8xf16>
+ return %0 : tensor<1x4x4x8xf16>
+}
+
+// -----
+
+func.func @test_conv3d_fp8_acc32(%arg0: tensor<1x4x8x21x17xf8E5M2>, %arg1: tensor<34x1x1x1x17xf8E5M2>, %arg2: tensor<34xf16>) -> tensor<1x4x8x21x34xf16> {
+ %input_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf8E5M2>}> : () -> tensor<1xf8E5M2>
+ %weight_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf8E5M2>}> : () -> tensor<1xf8E5M2>
+ // expected-error at +1 {{'tosa.conv3d' op illegal: the target specification version (1.0) is not backwards compatible with the op compliance specification version (1.1)}}
+ %0 = tosa.conv3d %arg0, %arg1, %arg2, %input_zp, %weight_zp {acc_type = f32, dilation = array<i64: 1, 1, 1>, pad = array<i64: 0, 0, 0, 0, 0, 0>, stride = array<i64: 1, 1, 1>} : (tensor<1x4x8x21x17xf8E5M2>, tensor<34x1x1x1x17xf8E5M2>, tensor<34xf16>, tensor<1xf8E5M2>, tensor<1xf8E5M2>) -> tensor<1x4x8x21x34xf16>
+ return %0 : tensor<1x4x8x21x34xf16>
+}
+
+// -----
+
+func.func @test_depthwise_conv2d_fp8_acc32(%arg0: tensor<1x4x4x4xf8E5M2>, %arg1: tensor<1x1x4x2xf8E5M2>, %arg2: tensor<8xf16>) -> tensor<1x4x4x8xf16> {
+ %input_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf8E5M2>}> : () -> tensor<1xf8E5M2>
+ %weight_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf8E5M2>}> : () -> tensor<1xf8E5M2>
+ // expected-error at +1 {{'tosa.depthwise_conv2d' op illegal: the target specification version (1.0) is not backwards compatible with the op compliance specification version (1.1)}}
+ %0 = tosa.depthwise_conv2d %arg0, %arg1, %arg2, %input_zp, %weight_zp {acc_type = f32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<1x4x4x4xf8E5M2>, tensor<1x1x4x2xf8E5M2>, tensor<8xf16>, tensor<1xf8E5M2>, tensor<1xf8E5M2>) -> tensor<1x4x4x8xf16>
+ return %0 : tensor<1x4x4x8xf16>
+}
+
+// -----
+
+func.func @test_transpose_conv2d_fp8_acc32(%arg0: tensor<1x32x32x8xf8E5M2>, %arg1: tensor<16x1x1x8xf8E5M2>, %arg2: tensor<16xf16>) -> tensor<1x32x32x16xf16> {
+ %input_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf8E5M2>}> : () -> tensor<1xf8E5M2>
+ %weight_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf8E5M2>}> : () -> tensor<1xf8E5M2>
+ // expected-error at +1 {{'tosa.transpose_conv2d' op illegal: the target specification version (1.0) is not backwards compatible with the op compliance specification version (1.1)}}
+ %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2, %input_zp, %weight_zp {acc_type = f32, out_pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<1x32x32x8xf8E5M2>, tensor<16x1x1x8xf8E5M2>, tensor<16xf16>, tensor<1xf8E5M2>, tensor<1xf8E5M2>) -> tensor<1x32x32x16xf16>
+ return %0 : tensor<1x32x32x16xf16>
+}
+
+// -----
+
func.func @test_dyanmic_dims(%arg0: tensor<?x8x16xi8>) -> tensor<?x16xi32> {
// expected-error at +1 {{'tosa.argmax' op failed level check: operand shape dimension cannot be dynamic when targeting TOSA specification version 1.0 or below}}
%0 = tosa.argmax %arg0 { axis = 1 : i32 } : (tensor<?x8x16xi8>) -> tensor<?x16xi32>
diff --git a/mlir/test/Dialect/Tosa/tosa-validation-version-1p1-valid.mlir b/mlir/test/Dialect/Tosa/tosa-validation-version-1p1-valid.mlir
index 63379ed8d8a4d..56ea20f4fbd40 100644
--- a/mlir/test/Dialect/Tosa/tosa-validation-version-1p1-valid.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-validation-version-1p1-valid.mlir
@@ -22,6 +22,46 @@ func.func @test_matmul_fp8_input_fp32_acc_type(%arg0: tensor<1x14x19xf8E4M3FN>,
// -----
+// CHECK-LABEL: test_conv2d_fp8_acc32
+func.func @test_conv2d_fp8_acc32(%arg0: tensor<1x4x4x4xf8E5M2>, %arg1: tensor<8x1x1x4xf8E5M2>, %arg2: tensor<8xf16>) -> tensor<1x4x4x8xf16> {
+ %input_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf8E5M2>}> : () -> tensor<1xf8E5M2>
+ %weight_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf8E5M2>}> : () -> tensor<1xf8E5M2>
+ %0 = tosa.conv2d %arg0, %arg1, %arg2, %input_zp, %weight_zp {acc_type = f32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<1x4x4x4xf8E5M2>, tensor<8x1x1x4xf8E5M2>, tensor<8xf16>, tensor<1xf8E5M2>, tensor<1xf8E5M2>) -> tensor<1x4x4x8xf16>
+ return %0 : tensor<1x4x4x8xf16>
+}
+
+// -----
+
+// CHECK-LABEL: test_conv3d_fp8_acc32
+func.func @test_conv3d_fp8_acc32(%arg0: tensor<1x4x8x21x17xf8E5M2>, %arg1: tensor<34x1x1x1x17xf8E5M2>, %arg2: tensor<34xf16>) -> tensor<1x4x8x21x34xf16> {
+ %input_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf8E5M2>}> : () -> tensor<1xf8E5M2>
+ %weight_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf8E5M2>}> : () -> tensor<1xf8E5M2>
+ %0 = tosa.conv3d %arg0, %arg1, %arg2, %input_zp, %weight_zp {acc_type = f32, dilation = array<i64: 1, 1, 1>, pad = array<i64: 0, 0, 0, 0, 0, 0>, stride = array<i64: 1, 1, 1>} : (tensor<1x4x8x21x17xf8E5M2>, tensor<34x1x1x1x17xf8E5M2>, tensor<34xf16>, tensor<1xf8E5M2>, tensor<1xf8E5M2>) -> tensor<1x4x8x21x34xf16>
+ return %0 : tensor<1x4x8x21x34xf16>
+}
+
+// -----
+
+// CHECK-LABEL: test_depthwise_conv2d_fp8_acc32
+func.func @test_depthwise_conv2d_fp8_acc32(%arg0: tensor<1x4x4x4xf8E5M2>, %arg1: tensor<1x1x4x2xf8E5M2>, %arg2: tensor<8xf16>) -> tensor<1x4x4x8xf16> {
+ %input_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf8E5M2>}> : () -> tensor<1xf8E5M2>
+ %weight_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf8E5M2>}> : () -> tensor<1xf8E5M2>
+ %0 = tosa.depthwise_conv2d %arg0, %arg1, %arg2, %input_zp, %weight_zp {acc_type = f32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<1x4x4x4xf8E5M2>, tensor<1x1x4x2xf8E5M2>, tensor<8xf16>, tensor<1xf8E5M2>, tensor<1xf8E5M2>) -> tensor<1x4x4x8xf16>
+ return %0 : tensor<1x4x4x8xf16>
+}
+
+// -----
+
+// CHECK-LABEL: test_transpose_conv2d_fp8_acc32
+func.func @test_transpose_conv2d_fp8_acc32(%arg0: tensor<1x32x32x8xf8E5M2>, %arg1: tensor<16x1x1x8xf8E5M2>, %arg2: tensor<16xf16>) -> tensor<1x32x32x16xf16> {
+ %input_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf8E5M2>}> : () -> tensor<1xf8E5M2>
+ %weight_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf8E5M2>}> : () -> tensor<1xf8E5M2>
+ %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2, %input_zp, %weight_zp {acc_type = f32, out_pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<1x32x32x8xf8E5M2>, tensor<16x1x1x8xf8E5M2>, tensor<16xf16>, tensor<1xf8E5M2>, tensor<1xf8E5M2>) -> tensor<1x32x32x16xf16>
+ return %0 : tensor<1x32x32x16xf16>
+}
+
+// -----
+
// CHECK-LABEL: test_matmul_t_block_scaled_fp6e2m3
func.func @test_matmul_t_block_scaled_fp6e2m3(%arg0: tensor<4x8x32xf6E2M3FN>, %arg1: tensor<4x8x1xf8E8M0FNU>, %arg2: tensor<4x16x32xf6E2M3FN>, %arg3: tensor<4x16x1xf8E8M0FNU>) -> tensor<4x8x16xf32> {
%0 = tosa.matmul_t_block_scaled %arg0, %arg1, %arg2, %arg3 {block_size = BLOCK_SIZE_32} : (tensor<4x8x32xf6E2M3FN>, tensor<4x8x1xf8E8M0FNU>, tensor<4x16x32xf6E2M3FN>, tensor<4x16x1xf8E8M0FNU>) -> tensor<4x8x16xf32>
>From 559bf4a635673181b2ed14ca3f7384b95ed7ca31 Mon Sep 17 00:00:00 2001
From: Luke Hutton <luke.hutton at arm.com>
Date: Mon, 12 Jan 2026 13:47:44 +0000
Subject: [PATCH 2/2] Update with fp32 output type
Update to align with the specification change.
Change-Id: Ie4e57c8007a2be7d1179c440e6a93afcee4b819d
---
.../Dialect/Tosa/IR/TosaComplianceData.h.inc | 16 ++++++-------
.../TosaToLinalg/tosa-to-linalg-named.mlir | 19 +++++----------
mlir/test/Dialect/Tosa/ops.mlir | 24 +++++++++----------
.../tosa-validation-version-1p0-invalid.mlir | 24 +++++++++----------
.../tosa-validation-version-1p1-valid.mlir | 24 +++++++++----------
5 files changed, 50 insertions(+), 57 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc b/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc
index 1ee7c38a5a03f..9e6471aa7d04e 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc
@@ -506,12 +506,12 @@ extensionComplianceMap = {
{{Extension::fp8e4m3},
{{{fp8e4m3T, fp8e4m3T, fp16T, fp8e4m3T, fp8e4m3T, fp16T, fp16T},
SpecificationVersion::V_1_0},
- {{fp8e4m3T, fp8e4m3T, fp16T, fp8e4m3T, fp8e4m3T, fp32T, fp16T},
+ {{fp8e4m3T, fp8e4m3T, fp32T, fp8e4m3T, fp8e4m3T, fp32T, fp32T},
SpecificationVersion::V_1_1_DRAFT}}},
{{Extension::fp8e5m2},
{{{fp8e5m2T, fp8e5m2T, fp16T, fp8e5m2T, fp8e5m2T, fp16T, fp16T},
SpecificationVersion::V_1_0},
- {{fp8e5m2T, fp8e5m2T, fp16T, fp8e5m2T, fp8e5m2T, fp32T, fp16T},
+ {{fp8e5m2T, fp8e5m2T, fp32T, fp8e5m2T, fp8e5m2T, fp32T, fp32T},
SpecificationVersion::V_1_1_DRAFT}}},
{{Extension::bf16},
{{{bf16T, bf16T, bf16T, bf16T, bf16T, fp32T, bf16T},
@@ -525,12 +525,12 @@ extensionComplianceMap = {
{{Extension::fp8e4m3},
{{{fp8e4m3T, fp8e4m3T, fp16T, fp8e4m3T, fp8e4m3T, fp16T, fp16T},
SpecificationVersion::V_1_0},
- {{fp8e4m3T, fp8e4m3T, fp16T, fp8e4m3T, fp8e4m3T, fp32T, fp16T},
+ {{fp8e4m3T, fp8e4m3T, fp32T, fp8e4m3T, fp8e4m3T, fp32T, fp32T},
SpecificationVersion::V_1_1_DRAFT}}},
{{Extension::fp8e5m2},
{{{fp8e5m2T, fp8e5m2T, fp16T, fp8e5m2T, fp8e5m2T, fp16T, fp16T},
SpecificationVersion::V_1_0},
- {{fp8e5m2T, fp8e5m2T, fp16T, fp8e5m2T, fp8e5m2T, fp32T, fp16T},
+ {{fp8e5m2T, fp8e5m2T, fp32T, fp8e5m2T, fp8e5m2T, fp32T, fp32T},
SpecificationVersion::V_1_1_DRAFT}}},
{{Extension::bf16},
{{{bf16T, bf16T, bf16T, bf16T, bf16T, fp32T, bf16T},
@@ -544,12 +544,12 @@ extensionComplianceMap = {
{{Extension::fp8e4m3},
{{{fp8e4m3T, fp8e4m3T, fp16T, fp8e4m3T, fp8e4m3T, fp16T, fp16T},
SpecificationVersion::V_1_0},
- {{fp8e4m3T, fp8e4m3T, fp16T, fp8e4m3T, fp8e4m3T, fp32T, fp16T},
+ {{fp8e4m3T, fp8e4m3T, fp32T, fp8e4m3T, fp8e4m3T, fp32T, fp32T},
SpecificationVersion::V_1_1_DRAFT}}},
{{Extension::fp8e5m2},
{{{fp8e5m2T, fp8e5m2T, fp16T, fp8e5m2T, fp8e5m2T, fp16T, fp16T},
SpecificationVersion::V_1_0},
- {{fp8e5m2T, fp8e5m2T, fp16T, fp8e5m2T, fp8e5m2T, fp32T, fp16T},
+ {{fp8e5m2T, fp8e5m2T, fp32T, fp8e5m2T, fp8e5m2T, fp32T, fp32T},
SpecificationVersion::V_1_1_DRAFT}}},
{{Extension::bf16},
{{{bf16T, bf16T, bf16T, bf16T, bf16T, fp32T, bf16T},
@@ -615,12 +615,12 @@ extensionComplianceMap = {
{{Extension::fp8e4m3},
{{{fp8e4m3T, fp8e4m3T, fp16T, fp8e4m3T, fp8e4m3T, fp16T, fp16T},
SpecificationVersion::V_1_0},
- {{fp8e4m3T, fp8e4m3T, fp16T, fp8e4m3T, fp8e4m3T, fp32T, fp16T},
+ {{fp8e4m3T, fp8e4m3T, fp32T, fp8e4m3T, fp8e4m3T, fp32T, fp32T},
SpecificationVersion::V_1_1_DRAFT}}},
{{Extension::fp8e5m2},
{{{fp8e5m2T, fp8e5m2T, fp16T, fp8e5m2T, fp8e5m2T, fp16T, fp16T},
SpecificationVersion::V_1_0},
- {{fp8e5m2T, fp8e5m2T, fp16T, fp8e5m2T, fp8e5m2T, fp32T, fp16T},
+ {{fp8e5m2T, fp8e5m2T, fp32T, fp8e5m2T, fp8e5m2T, fp32T, fp32T},
SpecificationVersion::V_1_1_DRAFT}}},
{{Extension::bf16},
{{{bf16T, bf16T, bf16T, bf16T, bf16T, fp32T, bf16T},
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
index c648fb5bdf8e3..27097ffc4ff84 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
@@ -659,14 +659,11 @@ func.func @conv2d_quant(%arg0 : tensor<1x12x12x1xi8>, %arg1 : tensor<1024x3x3x1x
// -----
// CHECK-LABEL: @conv2d_f16_f32_acc
-func.func @conv2d_f16_f32_acc(%input: tensor<1x49x42x27xf16>, %weights: tensor<28x3x3x27xf16>, %bias: tensor<28xf16>) -> () {
+func.func @conv2d_f16_f32_acc(%input: tensor<1x49x42x27xf16>, %weights: tensor<28x3x3x27xf16>, %bias: tensor<28xf32>) -> () {
%input_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf16>}> : () -> tensor<1xf16>
%weight_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf16>}> : () -> tensor<1xf16>
- // CHECK: linalg.generic {{{.*}}} ins(%{{.*}} : tensor<28xf16>) outs(%{{.*}} : tensor<1x45x40x28xf32>)
- // CHECK: arith.extf %{{.*}} : f16 to f32
// CHECK: %[[CONV:.*]] = linalg.conv_2d_nhwc_fhwc {{{.*}}} ins(%{{.*}}, %{{.*}} : tensor<1x49x42x27xf16>, tensor<28x3x3x27xf16>) outs(%{{.*}} : tensor<1x45x40x28xf32>) -> tensor<1x45x40x28xf32>
- // CHECK: tosa.cast %[[CONV]] : (tensor<1x45x40x28xf32>) -> tensor<1x45x40x28xf16>
- %0 = tosa.conv2d %input, %weights, %bias, %input_zp, %weight_zp {acc_type = f32, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, dilation = array<i64: 2, 1>} : (tensor<1x49x42x27xf16>, tensor<28x3x3x27xf16>, tensor<28xf16>, tensor<1xf16>, tensor<1xf16>) -> tensor<1x45x40x28xf16>
+ %0 = tosa.conv2d %input, %weights, %bias, %input_zp, %weight_zp {acc_type = f32, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, dilation = array<i64: 2, 1>} : (tensor<1x49x42x27xf16>, tensor<28x3x3x27xf16>, tensor<28xf32>, tensor<1xf16>, tensor<1xf16>) -> tensor<1x45x40x28xf32>
return
}
@@ -877,12 +874,11 @@ func.func @depthwise_int_conv_zero_zp(%arg0 : tensor<1x7x5x3xi8>, %arg1 : tensor
// -----
// CHECK-LABEL: @depthwise_conv2d_f16_f32_acc
-func.func @depthwise_conv2d_f16_f32_acc(%arg0 : tensor<1x7x5x3xf16>, %arg1 : tensor<3x1x3x11xf16>, %arg2 : tensor<33xf16>) -> () {
+func.func @depthwise_conv2d_f16_f32_acc(%arg0 : tensor<1x7x5x3xf16>, %arg1 : tensor<3x1x3x11xf16>, %arg2 : tensor<33xf32>) -> () {
%input_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf16>}> : () -> tensor<1xf16>
%weight_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf16>}> : () -> tensor<1xf16>
// CHECK: %[[CONV:.*]] = linalg.depthwise_conv_2d_nhwc_hwcm {{{.*}}} ins(%{{.*}}, %{{.*}} : tensor<1x7x5x3xf16>, tensor<3x1x3x11xf16>) outs(%{{.*}} : tensor<1x5x5x3x11xf32>) -> tensor<1x5x5x3x11xf32>
- // CHECK: tosa.cast %[[CONV]] : (tensor<1x5x5x3x11xf32>) -> tensor<1x5x5x3x11xf16>
- %2 = tosa.depthwise_conv2d %arg0, %arg1, %arg2, %input_zp, %weight_zp {acc_type = f32, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, dilation = array<i64: 1, 1> } : (tensor<1x7x5x3xf16>, tensor<3x1x3x11xf16>, tensor<33xf16>, tensor<1xf16>, tensor<1xf16>) -> tensor<1x5x5x33xf16>
+ %2 = tosa.depthwise_conv2d %arg0, %arg1, %arg2, %input_zp, %weight_zp {acc_type = f32, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, dilation = array<i64: 1, 1> } : (tensor<1x7x5x3xf16>, tensor<3x1x3x11xf16>, tensor<33xf32>, tensor<1xf16>, tensor<1xf16>) -> tensor<1x5x5x33xf32>
return
}
@@ -971,14 +967,11 @@ func.func @conv3d_i8(%input: tensor<1x49x48x47x27xi8>, %weights: tensor<43x3x4x5
// -----
// CHECK-LABEL: @conv3d_f16_f32_acc
-func.func @conv3d_f16_f32_acc(%input: tensor<1x49x48x47x27xf16>, %weights: tensor<43x3x4x5x27xf16>, %bias: tensor<43xf16>) -> () {
+func.func @conv3d_f16_f32_acc(%input: tensor<1x49x48x47x27xf16>, %weights: tensor<43x3x4x5x27xf16>, %bias: tensor<43xf32>) -> () {
%input_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf16>}> : () -> tensor<1xf16>
%weight_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf16>}> : () -> tensor<1xf16>
- // CHECK: linalg.generic {{{.*}}} ins(%{{.*}} : tensor<43xf16>) outs(%{{.*}} : tensor<1x47x45x43x43xf32>)
- // CHECK: arith.extf %{{.*}} : f16 to f32
// CHECK: %[[CONV:.*]] = linalg.conv_3d_ndhwc_dhwcf {{{.*}}} ins(%{{.*}}, %{{.*}} : tensor<1x49x48x47x27xf16>, tensor<3x4x5x27x43xf16>) outs(%{{.*}} : tensor<1x47x45x43x43xf32>) -> tensor<1x47x45x43x43xf32>
- // CHECK: tosa.cast %[[CONV]] : (tensor<1x47x45x43x43xf32>) -> tensor<1x47x45x43x43xf16>
- %0 = tosa.conv3d %input, %weights, %bias, %input_zp, %weight_zp {acc_type = f32, pad = array<i64: 0, 0, 0, 0, 0, 0>, stride = array<i64: 1, 1, 1>, dilation = array<i64: 1, 1, 1>} : (tensor<1x49x48x47x27xf16>, tensor<43x3x4x5x27xf16>, tensor<43xf16>, tensor<1xf16>, tensor<1xf16>) -> tensor<1x47x45x43x43xf16>
+ %0 = tosa.conv3d %input, %weights, %bias, %input_zp, %weight_zp {acc_type = f32, pad = array<i64: 0, 0, 0, 0, 0, 0>, stride = array<i64: 1, 1, 1>, dilation = array<i64: 1, 1, 1>} : (tensor<1x49x48x47x27xf16>, tensor<43x3x4x5x27xf16>, tensor<43xf32>, tensor<1xf16>, tensor<1xf16>) -> tensor<1x47x45x43x43xf32>
return
}
diff --git a/mlir/test/Dialect/Tosa/ops.mlir b/mlir/test/Dialect/Tosa/ops.mlir
index a6aaebb8b6a10..01a47c52fc681 100644
--- a/mlir/test/Dialect/Tosa/ops.mlir
+++ b/mlir/test/Dialect/Tosa/ops.mlir
@@ -983,11 +983,11 @@ func.func @test_conv2d_f8E5M2(%arg0: tensor<1x4x4x4xf8E5M2>, %arg1: tensor<8x1x1
// -----
// CHECK-LABEL: conv2d_f8E5M2_acc32
-func.func @test_conv2d_f8E5M2_acc32(%arg0: tensor<1x4x4x4xf8E5M2>, %arg1: tensor<8x1x1x4xf8E5M2>, %arg2: tensor<8xf16>) -> tensor<1x4x4x8xf16> {
+func.func @test_conv2d_f8E5M2_acc32(%arg0: tensor<1x4x4x4xf8E5M2>, %arg1: tensor<8x1x1x4xf8E5M2>, %arg2: tensor<8xf32>) -> tensor<1x4x4x8xf32> {
%input_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf8E5M2>}> : () -> tensor<1xf8E5M2>
%weight_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf8E5M2>}> : () -> tensor<1xf8E5M2>
- %0 = tosa.conv2d %arg0, %arg1, %arg2, %input_zp, %weight_zp {acc_type = f32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<1x4x4x4xf8E5M2>, tensor<8x1x1x4xf8E5M2>, tensor<8xf16>, tensor<1xf8E5M2>, tensor<1xf8E5M2>) -> tensor<1x4x4x8xf16>
- return %0 : tensor<1x4x4x8xf16>
+ %0 = tosa.conv2d %arg0, %arg1, %arg2, %input_zp, %weight_zp {acc_type = f32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<1x4x4x4xf8E5M2>, tensor<8x1x1x4xf8E5M2>, tensor<8xf32>, tensor<1xf8E5M2>, tensor<1xf8E5M2>) -> tensor<1x4x4x8xf32>
+ return %0 : tensor<1x4x4x8xf32>
}
// -----
@@ -999,11 +999,11 @@ func.func @test_conv3d_f8E5M2(%arg0: tensor<1x4x8x21x17xf8E5M2>, %arg1: tensor<3
// -----
// CHECK-LABEL: conv3d_f8E5M2_acc32
-func.func @test_conv3d_f8E5M2_acc32(%arg0: tensor<1x4x8x21x17xf8E5M2>, %arg1: tensor<34x1x1x1x17xf8E5M2>, %arg2: tensor<34xf16>) -> tensor<1x4x8x21x34xf16> {
+func.func @test_conv3d_f8E5M2_acc32(%arg0: tensor<1x4x8x21x17xf8E5M2>, %arg1: tensor<34x1x1x1x17xf8E5M2>, %arg2: tensor<34xf32>) -> tensor<1x4x8x21x34xf32> {
%input_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf8E5M2>}> : () -> tensor<1xf8E5M2>
%weight_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf8E5M2>}> : () -> tensor<1xf8E5M2>
- %0 = tosa.conv3d %arg0, %arg1, %arg2, %input_zp, %weight_zp {acc_type = f32, dilation = array<i64: 1, 1, 1>, pad = array<i64: 0, 0, 0, 0, 0, 0>, stride = array<i64: 1, 1, 1>} : (tensor<1x4x8x21x17xf8E5M2>, tensor<34x1x1x1x17xf8E5M2>, tensor<34xf16>, tensor<1xf8E5M2>, tensor<1xf8E5M2>) -> tensor<1x4x8x21x34xf16>
- return %0 : tensor<1x4x8x21x34xf16>
+ %0 = tosa.conv3d %arg0, %arg1, %arg2, %input_zp, %weight_zp {acc_type = f32, dilation = array<i64: 1, 1, 1>, pad = array<i64: 0, 0, 0, 0, 0, 0>, stride = array<i64: 1, 1, 1>} : (tensor<1x4x8x21x17xf8E5M2>, tensor<34x1x1x1x17xf8E5M2>, tensor<34xf32>, tensor<1xf8E5M2>, tensor<1xf8E5M2>) -> tensor<1x4x8x21x34xf32>
+ return %0 : tensor<1x4x8x21x34xf32>
}
// -----
@@ -1015,11 +1015,11 @@ func.func @test_depthwise_conv2d_f8E5M2(%arg0: tensor<1x4x4x4xf8E5M2>, %arg1: te
// -----
// CHECK-LABEL: depthwise_conv2d_f8E5M2_acc32
-func.func @test_depthwise_conv2d_f8E5M2_acc32(%arg0: tensor<1x4x4x4xf8E5M2>, %arg1: tensor<1x1x4x2xf8E5M2>, %arg2: tensor<8xf16>) -> tensor<1x4x4x8xf16> {
+func.func @test_depthwise_conv2d_f8E5M2_acc32(%arg0: tensor<1x4x4x4xf8E5M2>, %arg1: tensor<1x1x4x2xf8E5M2>, %arg2: tensor<8xf32>) -> tensor<1x4x4x8xf32> {
%input_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf8E5M2>}> : () -> tensor<1xf8E5M2>
%weight_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf8E5M2>}> : () -> tensor<1xf8E5M2>
- %0 = tosa.depthwise_conv2d %arg0, %arg1, %arg2, %input_zp, %weight_zp {acc_type = f32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<1x4x4x4xf8E5M2>, tensor<1x1x4x2xf8E5M2>, tensor<8xf16>, tensor<1xf8E5M2>, tensor<1xf8E5M2>) -> tensor<1x4x4x8xf16>
- return %0 : tensor<1x4x4x8xf16>
+ %0 = tosa.depthwise_conv2d %arg0, %arg1, %arg2, %input_zp, %weight_zp {acc_type = f32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<1x4x4x4xf8E5M2>, tensor<1x1x4x2xf8E5M2>, tensor<8xf32>, tensor<1xf8E5M2>, tensor<1xf8E5M2>) -> tensor<1x4x4x8xf32>
+ return %0 : tensor<1x4x4x8xf32>
}
// -----
@@ -1056,11 +1056,11 @@ func.func @test_transpose_conv2d_f8E5M2(%arg0: tensor<1x32x32x8xf8E5M2>, %arg1:
// -----
// CHECK-LABEL: transpose_conv2d_f8E5M2_acc32
-func.func @test_transpose_conv2d_f8E5M2_acc32(%arg0: tensor<1x32x32x8xf8E5M2>, %arg1: tensor<16x1x1x8xf8E5M2>, %arg2: tensor<16xf16>) -> tensor<1x32x32x16xf16> {
+func.func @test_transpose_conv2d_f8E5M2_acc32(%arg0: tensor<1x32x32x8xf8E5M2>, %arg1: tensor<16x1x1x8xf8E5M2>, %arg2: tensor<16xf32>) -> tensor<1x32x32x16xf32> {
%input_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf8E5M2>}> : () -> tensor<1xf8E5M2>
%weight_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf8E5M2>}> : () -> tensor<1xf8E5M2>
- %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2, %input_zp, %weight_zp {acc_type = f32, out_pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<1x32x32x8xf8E5M2>, tensor<16x1x1x8xf8E5M2>, tensor<16xf16>, tensor<1xf8E5M2>, tensor<1xf8E5M2>) -> tensor<1x32x32x16xf16>
- return %0 : tensor<1x32x32x16xf16>
+ %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2, %input_zp, %weight_zp {acc_type = f32, out_pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<1x32x32x8xf8E5M2>, tensor<16x1x1x8xf8E5M2>, tensor<16xf32>, tensor<1xf8E5M2>, tensor<1xf8E5M2>) -> tensor<1x32x32x16xf32>
+ return %0 : tensor<1x32x32x16xf32>
}
// -----
diff --git a/mlir/test/Dialect/Tosa/tosa-validation-version-1p0-invalid.mlir b/mlir/test/Dialect/Tosa/tosa-validation-version-1p0-invalid.mlir
index 9ecede6782da7..fe38e2f61b2e8 100644
--- a/mlir/test/Dialect/Tosa/tosa-validation-version-1p0-invalid.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-validation-version-1p0-invalid.mlir
@@ -22,42 +22,42 @@ func.func @test_matmul_fp8_input_fp32_acc_type(%arg0: tensor<1x14x19xf8E4M3FN>,
// -----
-func.func @test_conv2d_fp8_acc32(%arg0: tensor<1x4x4x4xf8E5M2>, %arg1: tensor<8x1x1x4xf8E5M2>, %arg2: tensor<8xf16>) -> tensor<1x4x4x8xf16> {
+func.func @test_conv2d_fp8_acc32(%arg0: tensor<1x4x4x4xf8E5M2>, %arg1: tensor<8x1x1x4xf8E5M2>, %arg2: tensor<8xf32>) -> tensor<1x4x4x8xf32> {
%input_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf8E5M2>}> : () -> tensor<1xf8E5M2>
%weight_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf8E5M2>}> : () -> tensor<1xf8E5M2>
// expected-error@+1 {{'tosa.conv2d' op illegal: the target specification version (1.0) is not backwards compatible with the op compliance specification version (1.1)}}
- %0 = tosa.conv2d %arg0, %arg1, %arg2, %input_zp, %weight_zp {acc_type = f32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<1x4x4x4xf8E5M2>, tensor<8x1x1x4xf8E5M2>, tensor<8xf16>, tensor<1xf8E5M2>, tensor<1xf8E5M2>) -> tensor<1x4x4x8xf16>
- return %0 : tensor<1x4x4x8xf16>
+ %0 = tosa.conv2d %arg0, %arg1, %arg2, %input_zp, %weight_zp {acc_type = f32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<1x4x4x4xf8E5M2>, tensor<8x1x1x4xf8E5M2>, tensor<8xf32>, tensor<1xf8E5M2>, tensor<1xf8E5M2>) -> tensor<1x4x4x8xf32>
+ return %0 : tensor<1x4x4x8xf32>
}
// -----
-func.func @test_conv3d_fp8_acc32(%arg0: tensor<1x4x8x21x17xf8E5M2>, %arg1: tensor<34x1x1x1x17xf8E5M2>, %arg2: tensor<34xf16>) -> tensor<1x4x8x21x34xf16> {
+func.func @test_conv3d_fp8_acc32(%arg0: tensor<1x4x8x21x17xf8E5M2>, %arg1: tensor<34x1x1x1x17xf8E5M2>, %arg2: tensor<34xf32>) -> tensor<1x4x8x21x34xf32> {
%input_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf8E5M2>}> : () -> tensor<1xf8E5M2>
%weight_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf8E5M2>}> : () -> tensor<1xf8E5M2>
// expected-error@+1 {{'tosa.conv3d' op illegal: the target specification version (1.0) is not backwards compatible with the op compliance specification version (1.1)}}
- %0 = tosa.conv3d %arg0, %arg1, %arg2, %input_zp, %weight_zp {acc_type = f32, dilation = array<i64: 1, 1, 1>, pad = array<i64: 0, 0, 0, 0, 0, 0>, stride = array<i64: 1, 1, 1>} : (tensor<1x4x8x21x17xf8E5M2>, tensor<34x1x1x1x17xf8E5M2>, tensor<34xf16>, tensor<1xf8E5M2>, tensor<1xf8E5M2>) -> tensor<1x4x8x21x34xf16>
- return %0 : tensor<1x4x8x21x34xf16>
+ %0 = tosa.conv3d %arg0, %arg1, %arg2, %input_zp, %weight_zp {acc_type = f32, dilation = array<i64: 1, 1, 1>, pad = array<i64: 0, 0, 0, 0, 0, 0>, stride = array<i64: 1, 1, 1>} : (tensor<1x4x8x21x17xf8E5M2>, tensor<34x1x1x1x17xf8E5M2>, tensor<34xf32>, tensor<1xf8E5M2>, tensor<1xf8E5M2>) -> tensor<1x4x8x21x34xf32>
+ return %0 : tensor<1x4x8x21x34xf32>
}
// -----
-func.func @test_depthwise_conv2d_fp8_acc32(%arg0: tensor<1x4x4x4xf8E5M2>, %arg1: tensor<1x1x4x2xf8E5M2>, %arg2: tensor<8xf16>) -> tensor<1x4x4x8xf16> {
+func.func @test_depthwise_conv2d_fp8_acc32(%arg0: tensor<1x4x4x4xf8E5M2>, %arg1: tensor<1x1x4x2xf8E5M2>, %arg2: tensor<8xf32>) -> tensor<1x4x4x8xf32> {
%input_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf8E5M2>}> : () -> tensor<1xf8E5M2>
%weight_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf8E5M2>}> : () -> tensor<1xf8E5M2>
// expected-error@+1 {{'tosa.depthwise_conv2d' op illegal: the target specification version (1.0) is not backwards compatible with the op compliance specification version (1.1)}}
- %0 = tosa.depthwise_conv2d %arg0, %arg1, %arg2, %input_zp, %weight_zp {acc_type = f32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<1x4x4x4xf8E5M2>, tensor<1x1x4x2xf8E5M2>, tensor<8xf16>, tensor<1xf8E5M2>, tensor<1xf8E5M2>) -> tensor<1x4x4x8xf16>
- return %0 : tensor<1x4x4x8xf16>
+ %0 = tosa.depthwise_conv2d %arg0, %arg1, %arg2, %input_zp, %weight_zp {acc_type = f32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<1x4x4x4xf8E5M2>, tensor<1x1x4x2xf8E5M2>, tensor<8xf32>, tensor<1xf8E5M2>, tensor<1xf8E5M2>) -> tensor<1x4x4x8xf32>
+ return %0 : tensor<1x4x4x8xf32>
}
// -----
-func.func @test_transpose_conv2d_fp8_acc32(%arg0: tensor<1x32x32x8xf8E5M2>, %arg1: tensor<16x1x1x8xf8E5M2>, %arg2: tensor<16xf16>) -> tensor<1x32x32x16xf16> {
+func.func @test_transpose_conv2d_fp8_acc32(%arg0: tensor<1x32x32x8xf8E5M2>, %arg1: tensor<16x1x1x8xf8E5M2>, %arg2: tensor<16xf32>) -> tensor<1x32x32x16xf32> {
%input_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf8E5M2>}> : () -> tensor<1xf8E5M2>
%weight_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf8E5M2>}> : () -> tensor<1xf8E5M2>
// expected-error@+1 {{'tosa.transpose_conv2d' op illegal: the target specification version (1.0) is not backwards compatible with the op compliance specification version (1.1)}}
- %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2, %input_zp, %weight_zp {acc_type = f32, out_pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<1x32x32x8xf8E5M2>, tensor<16x1x1x8xf8E5M2>, tensor<16xf16>, tensor<1xf8E5M2>, tensor<1xf8E5M2>) -> tensor<1x32x32x16xf16>
- return %0 : tensor<1x32x32x16xf16>
+ %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2, %input_zp, %weight_zp {acc_type = f32, out_pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<1x32x32x8xf8E5M2>, tensor<16x1x1x8xf8E5M2>, tensor<16xf32>, tensor<1xf8E5M2>, tensor<1xf8E5M2>) -> tensor<1x32x32x16xf32>
+ return %0 : tensor<1x32x32x16xf32>
}
// -----
diff --git a/mlir/test/Dialect/Tosa/tosa-validation-version-1p1-valid.mlir b/mlir/test/Dialect/Tosa/tosa-validation-version-1p1-valid.mlir
index 56ea20f4fbd40..f2b4429473ff9 100644
--- a/mlir/test/Dialect/Tosa/tosa-validation-version-1p1-valid.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-validation-version-1p1-valid.mlir
@@ -23,41 +23,41 @@ func.func @test_matmul_fp8_input_fp32_acc_type(%arg0: tensor<1x14x19xf8E4M3FN>,
// -----
// CHECK-LABEL: test_conv2d_fp8_acc32
-func.func @test_conv2d_fp8_acc32(%arg0: tensor<1x4x4x4xf8E5M2>, %arg1: tensor<8x1x1x4xf8E5M2>, %arg2: tensor<8xf16>) -> tensor<1x4x4x8xf16> {
+func.func @test_conv2d_fp8_acc32(%arg0: tensor<1x4x4x4xf8E5M2>, %arg1: tensor<8x1x1x4xf8E5M2>, %arg2: tensor<8xf32>) -> tensor<1x4x4x8xf32> {
%input_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf8E5M2>}> : () -> tensor<1xf8E5M2>
%weight_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf8E5M2>}> : () -> tensor<1xf8E5M2>
- %0 = tosa.conv2d %arg0, %arg1, %arg2, %input_zp, %weight_zp {acc_type = f32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<1x4x4x4xf8E5M2>, tensor<8x1x1x4xf8E5M2>, tensor<8xf16>, tensor<1xf8E5M2>, tensor<1xf8E5M2>) -> tensor<1x4x4x8xf16>
- return %0 : tensor<1x4x4x8xf16>
+ %0 = tosa.conv2d %arg0, %arg1, %arg2, %input_zp, %weight_zp {acc_type = f32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<1x4x4x4xf8E5M2>, tensor<8x1x1x4xf8E5M2>, tensor<8xf32>, tensor<1xf8E5M2>, tensor<1xf8E5M2>) -> tensor<1x4x4x8xf32>
+ return %0 : tensor<1x4x4x8xf32>
}
// -----
// CHECK-LABEL: test_conv3d_fp8_acc32
-func.func @test_conv3d_fp8_acc32(%arg0: tensor<1x4x8x21x17xf8E5M2>, %arg1: tensor<34x1x1x1x17xf8E5M2>, %arg2: tensor<34xf16>) -> tensor<1x4x8x21x34xf16> {
+func.func @test_conv3d_fp8_acc32(%arg0: tensor<1x4x8x21x17xf8E5M2>, %arg1: tensor<34x1x1x1x17xf8E5M2>, %arg2: tensor<34xf32>) -> tensor<1x4x8x21x34xf32> {
%input_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf8E5M2>}> : () -> tensor<1xf8E5M2>
%weight_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf8E5M2>}> : () -> tensor<1xf8E5M2>
- %0 = tosa.conv3d %arg0, %arg1, %arg2, %input_zp, %weight_zp {acc_type = f32, dilation = array<i64: 1, 1, 1>, pad = array<i64: 0, 0, 0, 0, 0, 0>, stride = array<i64: 1, 1, 1>} : (tensor<1x4x8x21x17xf8E5M2>, tensor<34x1x1x1x17xf8E5M2>, tensor<34xf16>, tensor<1xf8E5M2>, tensor<1xf8E5M2>) -> tensor<1x4x8x21x34xf16>
- return %0 : tensor<1x4x8x21x34xf16>
+ %0 = tosa.conv3d %arg0, %arg1, %arg2, %input_zp, %weight_zp {acc_type = f32, dilation = array<i64: 1, 1, 1>, pad = array<i64: 0, 0, 0, 0, 0, 0>, stride = array<i64: 1, 1, 1>} : (tensor<1x4x8x21x17xf8E5M2>, tensor<34x1x1x1x17xf8E5M2>, tensor<34xf32>, tensor<1xf8E5M2>, tensor<1xf8E5M2>) -> tensor<1x4x8x21x34xf32>
+ return %0 : tensor<1x4x8x21x34xf32>
}
// -----
// CHECK-LABEL: test_depthwise_conv2d_fp8_acc32
-func.func @test_depthwise_conv2d_fp8_acc32(%arg0: tensor<1x4x4x4xf8E5M2>, %arg1: tensor<1x1x4x2xf8E5M2>, %arg2: tensor<8xf16>) -> tensor<1x4x4x8xf16> {
+func.func @test_depthwise_conv2d_fp8_acc32(%arg0: tensor<1x4x4x4xf8E5M2>, %arg1: tensor<1x1x4x2xf8E5M2>, %arg2: tensor<8xf32>) -> tensor<1x4x4x8xf32> {
%input_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf8E5M2>}> : () -> tensor<1xf8E5M2>
%weight_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf8E5M2>}> : () -> tensor<1xf8E5M2>
- %0 = tosa.depthwise_conv2d %arg0, %arg1, %arg2, %input_zp, %weight_zp {acc_type = f32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<1x4x4x4xf8E5M2>, tensor<1x1x4x2xf8E5M2>, tensor<8xf16>, tensor<1xf8E5M2>, tensor<1xf8E5M2>) -> tensor<1x4x4x8xf16>
- return %0 : tensor<1x4x4x8xf16>
+ %0 = tosa.depthwise_conv2d %arg0, %arg1, %arg2, %input_zp, %weight_zp {acc_type = f32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<1x4x4x4xf8E5M2>, tensor<1x1x4x2xf8E5M2>, tensor<8xf32>, tensor<1xf8E5M2>, tensor<1xf8E5M2>) -> tensor<1x4x4x8xf32>
+ return %0 : tensor<1x4x4x8xf32>
}
// -----
// CHECK-LABEL: test_transpose_conv2d_fp8_acc32
-func.func @test_transpose_conv2d_fp8_acc32(%arg0: tensor<1x32x32x8xf8E5M2>, %arg1: tensor<16x1x1x8xf8E5M2>, %arg2: tensor<16xf16>) -> tensor<1x32x32x16xf16> {
+func.func @test_transpose_conv2d_fp8_acc32(%arg0: tensor<1x32x32x8xf8E5M2>, %arg1: tensor<16x1x1x8xf8E5M2>, %arg2: tensor<16xf32>) -> tensor<1x32x32x16xf32> {
%input_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf8E5M2>}> : () -> tensor<1xf8E5M2>
%weight_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf8E5M2>}> : () -> tensor<1xf8E5M2>
- %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2, %input_zp, %weight_zp {acc_type = f32, out_pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<1x32x32x8xf8E5M2>, tensor<16x1x1x8xf8E5M2>, tensor<16xf16>, tensor<1xf8E5M2>, tensor<1xf8E5M2>) -> tensor<1x32x32x16xf16>
- return %0 : tensor<1x32x32x16xf16>
+ %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2, %input_zp, %weight_zp {acc_type = f32, out_pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<1x32x32x8xf8E5M2>, tensor<16x1x1x8xf8E5M2>, tensor<16xf32>, tensor<1xf8E5M2>, tensor<1xf8E5M2>) -> tensor<1x32x32x16xf32>
+ return %0 : tensor<1x32x32x16xf32>
}
// -----
More information about the Mlir-commits
mailing list