[Mlir-commits] [mlir] [mlir][tosa] Add support for fp32 accumulators in fp8 convolution operations (PR #175517)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Jan 12 03:04:02 PST 2026
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Luke Hutton (lhutton1)
<details>
<summary>Changes</summary>
Implements the corresponding TOSA specification change: https://github.com/arm/tosa-specification/pull/16
---
The patch is 31.76 KiB and is truncated to 20.00 KiB below; the full version is available at https://github.com/llvm/llvm-project/pull/175517.diff
7 Files Affected:
- (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc (+24-8)
- (modified) mlir/lib/Dialect/Tosa/IR/TosaOps.cpp (+15-7)
- (modified) mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir (+40)
- (modified) mlir/test/Dialect/Tosa/invalid.mlir (+11-11)
- (modified) mlir/test/Dialect/Tosa/ops.mlir (+36)
- (modified) mlir/test/Dialect/Tosa/tosa-validation-version-1p0-invalid.mlir (+40)
- (modified) mlir/test/Dialect/Tosa/tosa-validation-version-1p1-valid.mlir (+40)
``````````diff
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc b/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc
index fd55cd82b8663..b1b602be9dfb2 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc
@@ -505,10 +505,14 @@ extensionComplianceMap = {
SpecificationVersion::V_1_0}}},
{{Extension::fp8e4m3},
{{{fp8e4m3T, fp8e4m3T, fp16T, fp8e4m3T, fp8e4m3T, fp16T, fp16T},
- SpecificationVersion::V_1_0}}},
+ SpecificationVersion::V_1_0},
+ {{fp8e4m3T, fp8e4m3T, fp16T, fp8e4m3T, fp8e4m3T, fp32T, fp16T},
+ SpecificationVersion::V_1_1_DRAFT}}},
{{Extension::fp8e5m2},
{{{fp8e5m2T, fp8e5m2T, fp16T, fp8e5m2T, fp8e5m2T, fp16T, fp16T},
- SpecificationVersion::V_1_0}}},
+ SpecificationVersion::V_1_0},
+ {{fp8e5m2T, fp8e5m2T, fp16T, fp8e5m2T, fp8e5m2T, fp32T, fp16T},
+ SpecificationVersion::V_1_1_DRAFT}}},
{{Extension::bf16},
{{{bf16T, bf16T, bf16T, bf16T, bf16T, fp32T, bf16T},
SpecificationVersion::V_1_0}}}}},
@@ -520,10 +524,14 @@ extensionComplianceMap = {
SpecificationVersion::V_1_0}}},
{{Extension::fp8e4m3},
{{{fp8e4m3T, fp8e4m3T, fp16T, fp8e4m3T, fp8e4m3T, fp16T, fp16T},
- SpecificationVersion::V_1_0}}},
+ SpecificationVersion::V_1_0},
+ {{fp8e4m3T, fp8e4m3T, fp16T, fp8e4m3T, fp8e4m3T, fp32T, fp16T},
+ SpecificationVersion::V_1_1_DRAFT}}},
{{Extension::fp8e5m2},
{{{fp8e5m2T, fp8e5m2T, fp16T, fp8e5m2T, fp8e5m2T, fp16T, fp16T},
- SpecificationVersion::V_1_0}}},
+ SpecificationVersion::V_1_0},
+ {{fp8e5m2T, fp8e5m2T, fp16T, fp8e5m2T, fp8e5m2T, fp32T, fp16T},
+ SpecificationVersion::V_1_1_DRAFT}}},
{{Extension::bf16},
{{{bf16T, bf16T, bf16T, bf16T, bf16T, fp32T, bf16T},
SpecificationVersion::V_1_0}}}}},
@@ -535,10 +543,14 @@ extensionComplianceMap = {
SpecificationVersion::V_1_0}}},
{{Extension::fp8e4m3},
{{{fp8e4m3T, fp8e4m3T, fp16T, fp8e4m3T, fp8e4m3T, fp16T, fp16T},
- SpecificationVersion::V_1_0}}},
+ SpecificationVersion::V_1_0},
+ {{fp8e4m3T, fp8e4m3T, fp16T, fp8e4m3T, fp8e4m3T, fp32T, fp16T},
+ SpecificationVersion::V_1_1_DRAFT}}},
{{Extension::fp8e5m2},
{{{fp8e5m2T, fp8e5m2T, fp16T, fp8e5m2T, fp8e5m2T, fp16T, fp16T},
- SpecificationVersion::V_1_0}}},
+ SpecificationVersion::V_1_0},
+ {{fp8e5m2T, fp8e5m2T, fp16T, fp8e5m2T, fp8e5m2T, fp32T, fp16T},
+ SpecificationVersion::V_1_1_DRAFT}}},
{{Extension::bf16},
{{{bf16T, bf16T, bf16T, bf16T, bf16T, fp32T, bf16T},
SpecificationVersion::V_1_0}}}}},
@@ -602,10 +614,14 @@ extensionComplianceMap = {
SpecificationVersion::V_1_0}}},
{{Extension::fp8e4m3},
{{{fp8e4m3T, fp8e4m3T, fp16T, fp8e4m3T, fp8e4m3T, fp16T, fp16T},
- SpecificationVersion::V_1_0}}},
+ SpecificationVersion::V_1_0},
+ {{fp8e4m3T, fp8e4m3T, fp16T, fp8e4m3T, fp8e4m3T, fp32T, fp16T},
+ SpecificationVersion::V_1_1_DRAFT}}},
{{Extension::fp8e5m2},
{{{fp8e5m2T, fp8e5m2T, fp16T, fp8e5m2T, fp8e5m2T, fp16T, fp16T},
- SpecificationVersion::V_1_0}}},
+ SpecificationVersion::V_1_0},
+ {{fp8e5m2T, fp8e5m2T, fp16T, fp8e5m2T, fp8e5m2T, fp32T, fp16T},
+ SpecificationVersion::V_1_1_DRAFT}}},
{{Extension::bf16},
{{{bf16T, bf16T, bf16T, bf16T, bf16T, fp32T, bf16T},
SpecificationVersion::V_1_0}}}}},
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 5656f3de698c5..6e934526d9035 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -731,22 +731,30 @@ static LogicalResult verifyConvOpModes(T op) {
auto accType = op.getAccType();
if (inputEType.isInteger(8) && !accType.isInteger(32))
- return op.emitOpError("accumulator type for i8 tensor is not i32");
+ return op.emitOpError("accumulator type for i8 tensor is not i32, got ")
+ << accType;
if (inputEType.isInteger(16) && !accType.isInteger(48))
- return op.emitOpError("accumulator type for i16 tensor is not i48");
+ return op.emitOpError("accumulator type for i16 tensor is not i48, got ")
+ << accType;
- if (isa<Float8E5M2Type, Float8E4M3Type>(inputEType) && !accType.isF16())
- return op.emitOpError("accumulator type for f8 tensor is not f16");
+ if (isa<Float8E5M2Type, Float8E4M3Type>(inputEType) &&
+ !(accType.isF16() || accType.isF32()))
+ return op.emitOpError("accumulator type for f8 tensor is not f16/f32, got ")
+ << accType;
if (inputEType.isF16() && !(accType.isF16() || accType.isF32()))
- return op.emitOpError("accumulator type for f16 tensor is not f16/f32");
+ return op.emitOpError(
+ "accumulator type for f16 tensor is not f16/f32, got ")
+ << accType;
if (inputEType.isBF16() && !accType.isF32())
- return op.emitOpError("accumulator type for bf16 tensor is not f32");
+ return op.emitOpError("accumulator type for bf16 tensor is not f32, got ")
+ << accType;
if (inputEType.isF32() && !accType.isF32())
- return op.emitOpError("accumulator type for f32 tensor is not f32");
+ return op.emitOpError("accumulator type for f32 tensor is not f32, got ")
+ << accType;
auto resultEType =
llvm::cast<ShapedType>(op.getResult().getType()).getElementType();
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
index 9ea224a2854e5..c648fb5bdf8e3 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
@@ -672,6 +672,20 @@ func.func @conv2d_f16_f32_acc(%input: tensor<1x49x42x27xf16>, %weights: tensor<2
// -----
+// CHECK-LABEL: @conv2d_f8_f32_acc
+func.func @conv2d_f8_f32_acc(%input: tensor<1x49x42x27xf8E5M2>, %weights: tensor<28x3x3x27xf8E5M2>, %bias: tensor<28xf16>) -> () {
+ %input_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf8E5M2>}> : () -> tensor<1xf8E5M2>
+ %weight_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf8E5M2>}> : () -> tensor<1xf8E5M2>
+ // CHECK: linalg.generic {{{.*}}} ins(%{{.*}} : tensor<28xf16>) outs(%{{.*}} : tensor<1x45x40x28xf32>)
+ // CHECK: arith.extf %{{.*}} : f16 to f32
+ // CHECK: %[[CONV:.*]] = linalg.conv_2d_nhwc_fhwc {{{.*}}} ins(%{{.*}}, %{{.*}} : tensor<1x49x42x27xf8E5M2>, tensor<28x3x3x27xf8E5M2>) outs(%{{.*}} : tensor<1x45x40x28xf32>) -> tensor<1x45x40x28xf32>
+ // CHECK: tosa.cast %[[CONV]] : (tensor<1x45x40x28xf32>) -> tensor<1x45x40x28xf16>
+ %0 = tosa.conv2d %input, %weights, %bias, %input_zp, %weight_zp {acc_type = f32, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, dilation = array<i64: 2, 1>} : (tensor<1x49x42x27xf8E5M2>, tensor<28x3x3x27xf8E5M2>, tensor<28xf16>, tensor<1xf8E5M2>, tensor<1xf8E5M2>) -> tensor<1x45x40x28xf16>
+ return
+}
+
+// -----
+
// CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d3)>
// CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
@@ -874,6 +888,18 @@ func.func @depthwise_conv2d_f16_f32_acc(%arg0 : tensor<1x7x5x3xf16>, %arg1 : ten
// -----
+// CHECK-LABEL: @depthwise_conv2d_f8_f32_acc
+func.func @depthwise_conv2d_f8_f32_acc(%arg0 : tensor<1x7x5x3xf8E5M2>, %arg1 : tensor<3x1x3x11xf8E5M2>, %arg2 : tensor<33xf16>) -> () {
+ %input_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf8E5M2>}> : () -> tensor<1xf8E5M2>
+ %weight_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf8E5M2>}> : () -> tensor<1xf8E5M2>
+ // CHECK: %[[CONV:.*]] = linalg.depthwise_conv_2d_nhwc_hwcm {{{.*}}} ins(%{{.*}}, %{{.*}} : tensor<1x7x5x3xf8E5M2>, tensor<3x1x3x11xf8E5M2>) outs(%{{.*}} : tensor<1x5x5x3x11xf32>) -> tensor<1x5x5x3x11xf32>
+ // CHECK: tosa.cast %[[CONV]] : (tensor<1x5x5x3x11xf32>) -> tensor<1x5x5x3x11xf16>
+ %2 = tosa.depthwise_conv2d %arg0, %arg1, %arg2, %input_zp, %weight_zp {acc_type = f32, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, dilation = array<i64: 1, 1> } : (tensor<1x7x5x3xf8E5M2>, tensor<3x1x3x11xf8E5M2>, tensor<33xf16>, tensor<1xf8E5M2>, tensor<1xf8E5M2>) -> tensor<1x5x5x33xf16>
+ return
+}
+
+// -----
+
// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d4)>
// CHECK: #[[$MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
@@ -958,6 +984,20 @@ func.func @conv3d_f16_f32_acc(%input: tensor<1x49x48x47x27xf16>, %weights: tenso
// -----
+// CHECK-LABEL: @conv3d_f8_f32_acc
+func.func @conv3d_f8_f32_acc(%input: tensor<1x49x48x47x27xf8E5M2>, %weights: tensor<43x3x4x5x27xf8E5M2>, %bias: tensor<43xf16>) -> () {
+ %input_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf8E5M2>}> : () -> tensor<1xf8E5M2>
+ %weight_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf8E5M2>}> : () -> tensor<1xf8E5M2>
+ // CHECK: linalg.generic {{{.*}}} ins(%{{.*}} : tensor<43xf16>) outs(%{{.*}} : tensor<1x47x45x43x43xf32>)
+ // CHECK: arith.extf %{{.*}} : f16 to f32
+ // CHECK: %[[CONV:.*]] = linalg.conv_3d_ndhwc_dhwcf {{{.*}}} ins(%{{.*}}, %{{.*}} : tensor<1x49x48x47x27xf8E5M2>, tensor<3x4x5x27x43xf8E5M2>) outs(%{{.*}} : tensor<1x47x45x43x43xf32>) -> tensor<1x47x45x43x43xf32>
+ // CHECK: tosa.cast %[[CONV]] : (tensor<1x47x45x43x43xf32>) -> tensor<1x47x45x43x43xf16>
+ %0 = tosa.conv3d %input, %weights, %bias, %input_zp, %weight_zp {acc_type = f32, pad = array<i64: 0, 0, 0, 0, 0, 0>, stride = array<i64: 1, 1, 1>, dilation = array<i64: 1, 1, 1>} : (tensor<1x49x48x47x27xf8E5M2>, tensor<43x3x4x5x27xf8E5M2>, tensor<43xf16>, tensor<1xf8E5M2>, tensor<1xf8E5M2>) -> tensor<1x47x45x43x43xf16>
+ return
+}
+
+// -----
+
// CHECK-LABEL: @test_transpose
// CHECK-SAME: (%[[ARG0:.+]]: tensor<1x2x3xi32>)
func.func @test_transpose(%arg0: tensor<1x2x3xi32>) -> () {
diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir
index e8206c24f1507..fb79070b3a8a4 100644
--- a/mlir/test/Dialect/Tosa/invalid.mlir
+++ b/mlir/test/Dialect/Tosa/invalid.mlir
@@ -76,7 +76,7 @@ func.func @test_conv2d_weight_zp(%arg0: tensor<1x29x29x4xf16>, %arg1: tensor<16x
func.func @test_conv2d_acc_type(%arg0: tensor<1x29x29x4xi8>, %arg1: tensor<16x3x3x4xi8>, %arg2: tensor<16xi8>) -> tensor<1x27x27x16xi8> {
%zp = "tosa.const"() {values = dense<0> : tensor<1xi8>} : () -> tensor<1xi8>
- // expected-error at +1 {{'tosa.conv2d' op accumulator type for i8 tensor is not i32}}
+ // expected-error at +1 {{'tosa.conv2d' op accumulator type for i8 tensor is not i32, got 'f16'}}
%0 = tosa.conv2d %arg0, %arg1, %arg2, %zp, %zp {acc_type = f16, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>}
: (tensor<1x29x29x4xi8>, tensor<16x3x3x4xi8>, tensor<16xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x27x27x16xi8>
return %0 : tensor<1x27x27x16xi8>
@@ -87,7 +87,7 @@ func.func @test_conv2d_acc_type(%arg0: tensor<1x29x29x4xi8>, %arg1: tensor<16x3x
func.func @test_conv2d_acc_type(%arg0: tensor<1x29x29x4xi16>, %arg1: tensor<16x3x3x4xi8>, %arg2: tensor<16xi16>) -> tensor<1x27x27x16xi16> {
%input_zp = "tosa.const"() {values = dense<0> : tensor<1xi16>} : () -> tensor<1xi16>
%weight_zp = "tosa.const"() {values = dense<0> : tensor<1xi8>} : () -> tensor<1xi8>
- // expected-error at +1 {{'tosa.conv2d' op accumulator type for i16 tensor is not i48}}
+ // expected-error at +1 {{'tosa.conv2d' op accumulator type for i16 tensor is not i48, got 'f16'}}
%0 = tosa.conv2d %arg0, %arg1, %arg2, %input_zp, %weight_zp {acc_type = f16, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>}
: (tensor<1x29x29x4xi16>, tensor<16x3x3x4xi8>, tensor<16xi16>, tensor<1xi16>, tensor<1xi8>) -> tensor<1x27x27x16xi16>
return %0 : tensor<1x27x27x16xi16>
@@ -97,7 +97,7 @@ func.func @test_conv2d_acc_type(%arg0: tensor<1x29x29x4xi16>, %arg1: tensor<16x3
func.func @test_conv2d_acc_type(%arg0: tensor<1x29x29x4xf8E5M2>, %arg1: tensor<16x3x3x4xf8E5M2>, %arg2: tensor<16xf16>) -> tensor<1x27x27x16xf16> {
%zp = "tosa.const"() {values = dense<0.0> : tensor<1xf8E5M2>} : () -> tensor<1xf8E5M2>
- // expected-error at +1 {{'tosa.conv2d' op accumulator type for f8 tensor is not f16}}
+ // expected-error at +1 {{'tosa.conv2d' op accumulator type for f8 tensor is not f16/f32, got 'i32'}}
%0 = tosa.conv2d %arg0, %arg1, %arg2, %zp, %zp {acc_type = i32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>}
: (tensor<1x29x29x4xf8E5M2>, tensor<16x3x3x4xf8E5M2>, tensor<16xf16>, tensor<1xf8E5M2>, tensor<1xf8E5M2>) -> tensor<1x27x27x16xf16>
return %0 : tensor<1x27x27x16xf16>
@@ -107,7 +107,7 @@ func.func @test_conv2d_acc_type(%arg0: tensor<1x29x29x4xf8E5M2>, %arg1: tensor<1
func.func @test_conv2d_acc_type(%arg0: tensor<1x29x29x4xf8E4M3>, %arg1: tensor<16x3x3x4xf8E4M3>, %arg2: tensor<16xf16>) -> tensor<1x27x27x16xf16> {
%zp = "tosa.const"() {values = dense<0.0> : tensor<1xf8E4M3>} : () -> tensor<1xf8E4M3>
- // expected-error at +1 {{'tosa.conv2d' op accumulator type for f8 tensor is not f16}}
+ // expected-error at +1 {{'tosa.conv2d' op accumulator type for f8 tensor is not f16/f32, got 'i32'}}
%0 = tosa.conv2d %arg0, %arg1, %arg2, %zp, %zp {acc_type = i32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>}
: (tensor<1x29x29x4xf8E4M3>, tensor<16x3x3x4xf8E4M3>, tensor<16xf16>, tensor<1xf8E4M3>, tensor<1xf8E4M3>) -> tensor<1x27x27x16xf16>
return %0 : tensor<1x27x27x16xf16>
@@ -117,7 +117,7 @@ func.func @test_conv2d_acc_type(%arg0: tensor<1x29x29x4xf8E4M3>, %arg1: tensor<1
func.func @test_conv2d_acc_type(%arg0: tensor<1x29x29x4xf16>, %arg1: tensor<16x3x3x4xf16>, %arg2: tensor<16xf16>) -> tensor<1x27x27x16xf16> {
%zp = "tosa.const"() {values = dense<0.0> : tensor<1xf16>} : () -> tensor<1xf16>
- // expected-error at +1 {{'tosa.conv2d' op accumulator type for f16 tensor is not f16/f32}}
+ // expected-error at +1 {{'tosa.conv2d' op accumulator type for f16 tensor is not f16/f32, got 'i32'}}
%0 = tosa.conv2d %arg0, %arg1, %arg2, %zp, %zp {acc_type = i32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>}
: (tensor<1x29x29x4xf16>, tensor<16x3x3x4xf16>, tensor<16xf16>, tensor<1xf16>, tensor<1xf16>) -> tensor<1x27x27x16xf16>
return %0 : tensor<1x27x27x16xf16>
@@ -127,7 +127,7 @@ func.func @test_conv2d_acc_type(%arg0: tensor<1x29x29x4xf16>, %arg1: tensor<16x3
func.func @test_conv2d_acc_type(%arg0: tensor<1x29x29x4xbf16>, %arg1: tensor<16x3x3x4xbf16>, %arg2: tensor<16xbf16>) -> tensor<1x27x27x16xbf16> {
%zp = "tosa.const"() {values = dense<0.0> : tensor<1xbf16>} : () -> tensor<1xbf16>
- // expected-error at +1 {{'tosa.conv2d' op accumulator type for bf16 tensor is not f32}}
+ // expected-error at +1 {{'tosa.conv2d' op accumulator type for bf16 tensor is not f32, got 'i32'}}
%0 = tosa.conv2d %arg0, %arg1, %arg2, %zp, %zp {acc_type = i32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>}
: (tensor<1x29x29x4xbf16>, tensor<16x3x3x4xbf16>, tensor<16xbf16>, tensor<1xbf16>, tensor<1xbf16>) -> tensor<1x27x27x16xbf16>
return %0 : tensor<1x27x27x16xbf16>
@@ -137,7 +137,7 @@ func.func @test_conv2d_acc_type(%arg0: tensor<1x29x29x4xbf16>, %arg1: tensor<16x
func.func @test_conv2d_acc_type(%arg0: tensor<1x29x29x4xf32>, %arg1: tensor<16x3x3x4xf32>, %arg2: tensor<16xf32>) -> tensor<1x27x27x16xf32> {
%zp = "tosa.const"() {values = dense<0.0> : tensor<1xf32>} : () -> tensor<1xf32>
- // expected-error at +1 {{'tosa.conv2d' op accumulator type for f32 tensor is not f32}}
+ // expected-error at +1 {{'tosa.conv2d' op accumulator type for f32 tensor is not f32, got 'i32'}}
%0 = tosa.conv2d %arg0, %arg1, %arg2, %zp, %zp {acc_type = i32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>}
: (tensor<1x29x29x4xf32>, tensor<16x3x3x4xf32>, tensor<16xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x27x27x16xf32>
return %0 : tensor<1x27x27x16xf32>
@@ -147,7 +147,7 @@ func.func @test_conv2d_acc_type(%arg0: tensor<1x29x29x4xf32>, %arg1: tensor<16x3
func.func @test_conv3d_acc_type(%arg0: tensor<1x4x8x21x17xi8>, %arg1: tensor<34x1x1x1x17xi8>, %arg2: tensor<34xi8>) -> tensor<1x4x8x21x34xi8> {
%zp = "tosa.const"() {values = dense<0> : tensor<1xi8>} : () -> tensor<1xi8>
- // expected-error at +1 {{'tosa.conv3d' op accumulator type for i8 tensor is not i32}}
+ // expected-error at +1 {{'tosa.conv3d' op accumulator type for i8 tensor is not i32, got 'f16'}}
%0 = tosa.conv3d %arg0, %arg1, %arg2, %zp, %zp {acc_type = f16, dilation = array<i64: 1, 1, 1>, pad = array<i64: 0, 0, 0, 0, 0, 0>, stride = array<i64: 1, 1, 1>}
: (tensor<1x4x8x21x17xi8>, tensor<34x1x1x1x17xi8>, tensor<34xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x4x8x21x34xi8>
return %0 : tensor<1x4x8x21x34xi8>
@@ -157,7 +157,7 @@ func.func @test_conv3d_acc_type(%arg0: tensor<1x4x8x21x17xi8>, %arg1: tensor<34x
func.func @test_depthwise_conv2d_acc_type(%arg0: tensor<1x4x4x4xi8>, %arg1: tensor<1x1x4x2xi8>, %arg2: tensor<8xi8>) -> tensor<1x4x4x8xi8> {
%zp = "tosa.const"() {values = dense<0> : tensor<1xi8>} : () -> tensor<1xi8>
- // expected-error at +1 {{'tosa.depthwise_conv2d' op accumulator type for i8 tensor is not i32}}
+ // expected-error at +1 {{'tosa.depthwise_conv2d' op accumulator type for i8 tensor is not i32, got 'f16'}}
%0 = tosa.depthwise_conv2d %arg0, %arg1, %arg2, %zp, %zp {acc_type = f16, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<1x4x4x4xi8>, tensor<1x1x4x2xi8>, tensor<8xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x4x4x8xi8>
return %0 : tensor<1x4x4x8xi8>
}
@@ -166,7 +166,7 @@ func.func @test_depthwise_conv2d_acc_type(%arg0: tensor<1x4x4x4xi8>, %arg1: tens
func.func @test_transpose_conv2d(%arg0: tensor<1x32x32x8xi8>, %arg1: tensor<16x1x1x8xi8>, %arg2: tensor<16xi8>) -> tensor<1x32x32x16xi8> {
%zp = "tosa.const"() {values = dense<0> : tensor<1xi8>} : () -> tensor<1xi8>
- // expected-error at +1 {{'tosa.transpose_conv2d' op accumulator type for i8 tensor is not i32}}
+ // expected-error at +1 {{'tosa.transpose_conv2d' op accumulator type for i8 tensor is not i32, got 'f16'}}
%0 = tosa.transpose_conv2d %arg0, %arg1, %arg2, %zp, %zp {acc_type = f16, out_pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<1x32x32x8xi8>, tensor<16x1x1x8xi8>, tensor<16xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x32x32x16xi8>
return %0 : tensor<1x32x32x16xi8>
}
@@ -247,7 +247,7 @@ func.func @test_transpose_conv2d_invalid_bias(%arg0: tensor<1x32x32x8xf32>, %arg
// CHECK-LABEL: conv2d_quant_any_acc
func.func @test_conv2d_quant_any_acc(%arg0: tensor<1x4x4x4x!quant.any<i8<-8:7>>>, %arg1: tensor<8x1x1x4x!quant.any<i8<-8:7>>>, %arg2: tensor<8x!quant.any<i8<-8:7>>>) -> tensor<1x4x4x8x!quant.any<i8<-8:7>>> {
%zp = "tosa.const" () { values = dense<0> : tensor<1xi8> } : () -> tensor<1xi8>
- // expected-error at +1 {{'tosa.conv2d' op accumulator type for i8 tensor is not i32}}
+ // expected-error at +1 {{'tosa.conv2d' op accumulator type for i8 tensor is not i32, got 'f32'}}
%0 = tosa.conv2d %arg0, %arg1, %arg2, %zp, %zp {acc_type = f32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, local_bound = true} : (tensor<1x4x4x4x!quant.any<i8<-8:7>>>, tensor<8x1x1x4x!quant.any<i8<-8:7>>>, tensor<8x!quant.any<i8<-8:7>>>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x4x4x8x!quant.any<i8<-8:7>>>
return %0 : tensor<1x4x4x8x!quant.any<i8<-8:7>>>
}
diff --git a/mlir/test/Dialect/Tosa/ops.mlir b/mlir/test/Dialect/Tosa/ops.mlir
index 276eac4d6166d..a6aaebb8b6a10 100644
--- a/mlir/test/Dialect/Tosa/ops.mlir
+++ b/mlir/test/Dialect/Tosa/ops.mlir
@@ -981,6 +981,15 @@ func.func @test_conv2d_f8E5M2(%arg0: tensor<1x4x4x4xf8E5M2>, %arg1: tensor<8x1x1
return %0 : tensor<1x4x4x8xf16>
}
+// -----
+// CHECK-LABEL: conv2d_f8E5M2_acc32
+func.func @test_conv2d_f8E5M2_acc32(%arg0: tensor<1x4x4x4xf8E5M2>, %arg1: tensor<8x1x1x4xf8E5M2>, %arg2: tensor<8xf16>) -> tensor<1x4x4x8xf16> {
+ %input_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf8E5M2>}> : () -> tensor<1xf8E5M2>
+ %we...
[truncated]
``````````
</details>
https://github.com/llvm/llvm-project/pull/175517
More information about the Mlir-commits mailing list