[Mlir-commits] [mlir] [mlir][tosa] Add support for fp32 accumulators in fp8 convolution operations (PR #175517)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Jan 12 03:04:02 PST 2026


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Luke Hutton (lhutton1)

<details>
<summary>Changes</summary>

This change adds support for fp32 accumulators in fp8 convolution operations, as per the TOSA specification change: https://github.com/arm/tosa-specification/pull/16

---

Patch is 31.76 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/175517.diff


7 Files Affected:

- (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc (+24-8) 
- (modified) mlir/lib/Dialect/Tosa/IR/TosaOps.cpp (+15-7) 
- (modified) mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir (+40) 
- (modified) mlir/test/Dialect/Tosa/invalid.mlir (+11-11) 
- (modified) mlir/test/Dialect/Tosa/ops.mlir (+36) 
- (modified) mlir/test/Dialect/Tosa/tosa-validation-version-1p0-invalid.mlir (+40) 
- (modified) mlir/test/Dialect/Tosa/tosa-validation-version-1p1-valid.mlir (+40) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc b/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc
index fd55cd82b8663..b1b602be9dfb2 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc
@@ -505,10 +505,14 @@ extensionComplianceMap = {
          SpecificationVersion::V_1_0}}},
       {{Extension::fp8e4m3},
        {{{fp8e4m3T, fp8e4m3T, fp16T, fp8e4m3T, fp8e4m3T, fp16T, fp16T},
-         SpecificationVersion::V_1_0}}},
+          SpecificationVersion::V_1_0},
+         {{fp8e4m3T, fp8e4m3T, fp16T, fp8e4m3T, fp8e4m3T, fp32T, fp16T},
+          SpecificationVersion::V_1_1_DRAFT}}},
       {{Extension::fp8e5m2},
        {{{fp8e5m2T, fp8e5m2T, fp16T, fp8e5m2T, fp8e5m2T, fp16T, fp16T},
-         SpecificationVersion::V_1_0}}},
+          SpecificationVersion::V_1_0},
+         {{fp8e5m2T, fp8e5m2T, fp16T, fp8e5m2T, fp8e5m2T, fp32T, fp16T},
+          SpecificationVersion::V_1_1_DRAFT}}},
       {{Extension::bf16},
        {{{bf16T, bf16T, bf16T, bf16T, bf16T, fp32T, bf16T},
          SpecificationVersion::V_1_0}}}}},
@@ -520,10 +524,14 @@ extensionComplianceMap = {
          SpecificationVersion::V_1_0}}},
       {{Extension::fp8e4m3},
        {{{fp8e4m3T, fp8e4m3T, fp16T, fp8e4m3T, fp8e4m3T, fp16T, fp16T},
-         SpecificationVersion::V_1_0}}},
+          SpecificationVersion::V_1_0},
+         {{fp8e4m3T, fp8e4m3T, fp16T, fp8e4m3T, fp8e4m3T, fp32T, fp16T},
+          SpecificationVersion::V_1_1_DRAFT}}},
       {{Extension::fp8e5m2},
        {{{fp8e5m2T, fp8e5m2T, fp16T, fp8e5m2T, fp8e5m2T, fp16T, fp16T},
-         SpecificationVersion::V_1_0}}},
+          SpecificationVersion::V_1_0},
+         {{fp8e5m2T, fp8e5m2T, fp16T, fp8e5m2T, fp8e5m2T, fp32T, fp16T},
+          SpecificationVersion::V_1_1_DRAFT}}},
       {{Extension::bf16},
        {{{bf16T, bf16T, bf16T, bf16T, bf16T, fp32T, bf16T},
          SpecificationVersion::V_1_0}}}}},
@@ -535,10 +543,14 @@ extensionComplianceMap = {
          SpecificationVersion::V_1_0}}},
       {{Extension::fp8e4m3},
        {{{fp8e4m3T, fp8e4m3T, fp16T, fp8e4m3T, fp8e4m3T, fp16T, fp16T},
-         SpecificationVersion::V_1_0}}},
+          SpecificationVersion::V_1_0},
+         {{fp8e4m3T, fp8e4m3T, fp16T, fp8e4m3T, fp8e4m3T, fp32T, fp16T},
+          SpecificationVersion::V_1_1_DRAFT}}},
       {{Extension::fp8e5m2},
        {{{fp8e5m2T, fp8e5m2T, fp16T, fp8e5m2T, fp8e5m2T, fp16T, fp16T},
-         SpecificationVersion::V_1_0}}},
+          SpecificationVersion::V_1_0},
+         {{fp8e5m2T, fp8e5m2T, fp16T, fp8e5m2T, fp8e5m2T, fp32T, fp16T},
+          SpecificationVersion::V_1_1_DRAFT}}},
       {{Extension::bf16},
        {{{bf16T, bf16T, bf16T, bf16T, bf16T, fp32T, bf16T},
          SpecificationVersion::V_1_0}}}}},
@@ -602,10 +614,14 @@ extensionComplianceMap = {
          SpecificationVersion::V_1_0}}},
       {{Extension::fp8e4m3},
        {{{fp8e4m3T, fp8e4m3T, fp16T, fp8e4m3T, fp8e4m3T, fp16T, fp16T},
-         SpecificationVersion::V_1_0}}},
+          SpecificationVersion::V_1_0},
+         {{fp8e4m3T, fp8e4m3T, fp16T, fp8e4m3T, fp8e4m3T, fp32T, fp16T},
+          SpecificationVersion::V_1_1_DRAFT}}},
       {{Extension::fp8e5m2},
        {{{fp8e5m2T, fp8e5m2T, fp16T, fp8e5m2T, fp8e5m2T, fp16T, fp16T},
-         SpecificationVersion::V_1_0}}},
+          SpecificationVersion::V_1_0},
+         {{fp8e5m2T, fp8e5m2T, fp16T, fp8e5m2T, fp8e5m2T, fp32T, fp16T},
+          SpecificationVersion::V_1_1_DRAFT}}},
       {{Extension::bf16},
        {{{bf16T, bf16T, bf16T, bf16T, bf16T, fp32T, bf16T},
          SpecificationVersion::V_1_0}}}}},
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 5656f3de698c5..6e934526d9035 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -731,22 +731,30 @@ static LogicalResult verifyConvOpModes(T op) {
 
   auto accType = op.getAccType();
   if (inputEType.isInteger(8) && !accType.isInteger(32))
-    return op.emitOpError("accumulator type for i8 tensor is not i32");
+    return op.emitOpError("accumulator type for i8 tensor is not i32, got ")
+           << accType;
 
   if (inputEType.isInteger(16) && !accType.isInteger(48))
-    return op.emitOpError("accumulator type for i16 tensor is not i48");
+    return op.emitOpError("accumulator type for i16 tensor is not i48, got ")
+           << accType;
 
-  if (isa<Float8E5M2Type, Float8E4M3Type>(inputEType) && !accType.isF16())
-    return op.emitOpError("accumulator type for f8 tensor is not f16");
+  if (isa<Float8E5M2Type, Float8E4M3Type>(inputEType) &&
+      !(accType.isF16() || accType.isF32()))
+    return op.emitOpError("accumulator type for f8 tensor is not f16/f32, got ")
+           << accType;
 
   if (inputEType.isF16() && !(accType.isF16() || accType.isF32()))
-    return op.emitOpError("accumulator type for f16 tensor is not f16/f32");
+    return op.emitOpError(
+               "accumulator type for f16 tensor is not f16/f32, got ")
+           << accType;
 
   if (inputEType.isBF16() && !accType.isF32())
-    return op.emitOpError("accumulator type for bf16 tensor is not f32");
+    return op.emitOpError("accumulator type for bf16 tensor is not f32, got ")
+           << accType;
 
   if (inputEType.isF32() && !accType.isF32())
-    return op.emitOpError("accumulator type for f32 tensor is not f32");
+    return op.emitOpError("accumulator type for f32 tensor is not f32, got ")
+           << accType;
 
   auto resultEType =
       llvm::cast<ShapedType>(op.getResult().getType()).getElementType();
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
index 9ea224a2854e5..c648fb5bdf8e3 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
@@ -672,6 +672,20 @@ func.func @conv2d_f16_f32_acc(%input: tensor<1x49x42x27xf16>, %weights: tensor<2
 
 // -----
 
+// CHECK-LABEL: @conv2d_f8_f32_acc
+func.func @conv2d_f8_f32_acc(%input: tensor<1x49x42x27xf8E5M2>, %weights: tensor<28x3x3x27xf8E5M2>, %bias: tensor<28xf16>) -> () {
+  %input_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf8E5M2>}> : () -> tensor<1xf8E5M2>
+  %weight_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf8E5M2>}> : () -> tensor<1xf8E5M2>
+  // CHECK: linalg.generic {{{.*}}} ins(%{{.*}} : tensor<28xf16>) outs(%{{.*}} : tensor<1x45x40x28xf32>)
+  // CHECK: arith.extf %{{.*}} : f16 to f32
+  // CHECK: %[[CONV:.*]] = linalg.conv_2d_nhwc_fhwc {{{.*}}} ins(%{{.*}}, %{{.*}} : tensor<1x49x42x27xf8E5M2>, tensor<28x3x3x27xf8E5M2>) outs(%{{.*}} : tensor<1x45x40x28xf32>) -> tensor<1x45x40x28xf32>
+  // CHECK: tosa.cast %[[CONV]] : (tensor<1x45x40x28xf32>) -> tensor<1x45x40x28xf16>
+  %0 = tosa.conv2d %input, %weights, %bias, %input_zp, %weight_zp {acc_type = f32, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, dilation = array<i64: 2, 1>} : (tensor<1x49x42x27xf8E5M2>, tensor<28x3x3x27xf8E5M2>, tensor<28xf16>, tensor<1xf8E5M2>, tensor<1xf8E5M2>) -> tensor<1x45x40x28xf16>
+  return
+}
+
+// -----
+
 // CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d3)>
 // CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
 
@@ -874,6 +888,18 @@ func.func @depthwise_conv2d_f16_f32_acc(%arg0 : tensor<1x7x5x3xf16>, %arg1 : ten
 
 // -----
 
+// CHECK-LABEL: @depthwise_conv2d_f8_f32_acc
+func.func @depthwise_conv2d_f8_f32_acc(%arg0 : tensor<1x7x5x3xf8E5M2>, %arg1 : tensor<3x1x3x11xf8E5M2>, %arg2 : tensor<33xf16>) -> () {
+  %input_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf8E5M2>}> : () -> tensor<1xf8E5M2>
+  %weight_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf8E5M2>}> : () -> tensor<1xf8E5M2>
+  // CHECK: %[[CONV:.*]] = linalg.depthwise_conv_2d_nhwc_hwcm {{{.*}}} ins(%{{.*}}, %{{.*}} : tensor<1x7x5x3xf8E5M2>, tensor<3x1x3x11xf8E5M2>) outs(%{{.*}} : tensor<1x5x5x3x11xf32>) -> tensor<1x5x5x3x11xf32>
+  // CHECK: tosa.cast %[[CONV]] : (tensor<1x5x5x3x11xf32>) -> tensor<1x5x5x3x11xf16>
+  %2 = tosa.depthwise_conv2d %arg0, %arg1, %arg2, %input_zp, %weight_zp {acc_type = f32, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, dilation = array<i64: 1, 1> } : (tensor<1x7x5x3xf8E5M2>, tensor<3x1x3x11xf8E5M2>, tensor<33xf16>, tensor<1xf8E5M2>, tensor<1xf8E5M2>) -> tensor<1x5x5x33xf16>
+  return
+}
+
+// -----
+
 // CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d4)>
 // CHECK: #[[$MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
 
@@ -958,6 +984,20 @@ func.func @conv3d_f16_f32_acc(%input: tensor<1x49x48x47x27xf16>, %weights: tenso
 
 // -----
 
+// CHECK-LABEL: @conv3d_f8_f32_acc
+func.func @conv3d_f8_f32_acc(%input: tensor<1x49x48x47x27xf8E5M2>, %weights: tensor<43x3x4x5x27xf8E5M2>, %bias: tensor<43xf16>) -> () {
+  %input_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf8E5M2>}> : () -> tensor<1xf8E5M2>
+  %weight_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf8E5M2>}> : () -> tensor<1xf8E5M2>
+  // CHECK: linalg.generic {{{.*}}} ins(%{{.*}} : tensor<43xf16>) outs(%{{.*}} : tensor<1x47x45x43x43xf32>)
+  // CHECK: arith.extf %{{.*}} : f16 to f32
+  // CHECK: %[[CONV:.*]] = linalg.conv_3d_ndhwc_dhwcf {{{.*}}} ins(%{{.*}}, %{{.*}} : tensor<1x49x48x47x27xf8E5M2>, tensor<3x4x5x27x43xf8E5M2>) outs(%{{.*}} : tensor<1x47x45x43x43xf32>) -> tensor<1x47x45x43x43xf32>
+  // CHECK: tosa.cast %[[CONV]] : (tensor<1x47x45x43x43xf32>) -> tensor<1x47x45x43x43xf16>
+  %0 = tosa.conv3d %input, %weights, %bias, %input_zp, %weight_zp {acc_type = f32, pad = array<i64: 0, 0, 0, 0, 0, 0>, stride = array<i64: 1, 1, 1>, dilation = array<i64: 1, 1, 1>} : (tensor<1x49x48x47x27xf8E5M2>, tensor<43x3x4x5x27xf8E5M2>, tensor<43xf16>, tensor<1xf8E5M2>, tensor<1xf8E5M2>) -> tensor<1x47x45x43x43xf16>
+  return
+}
+
+// -----
+
 // CHECK-LABEL: @test_transpose
 // CHECK-SAME: (%[[ARG0:.+]]: tensor<1x2x3xi32>)
 func.func @test_transpose(%arg0: tensor<1x2x3xi32>) -> () {
diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir
index e8206c24f1507..fb79070b3a8a4 100644
--- a/mlir/test/Dialect/Tosa/invalid.mlir
+++ b/mlir/test/Dialect/Tosa/invalid.mlir
@@ -76,7 +76,7 @@ func.func @test_conv2d_weight_zp(%arg0: tensor<1x29x29x4xf16>, %arg1: tensor<16x
 
 func.func @test_conv2d_acc_type(%arg0: tensor<1x29x29x4xi8>, %arg1: tensor<16x3x3x4xi8>, %arg2: tensor<16xi8>) -> tensor<1x27x27x16xi8> {
   %zp = "tosa.const"() {values = dense<0> : tensor<1xi8>} : () -> tensor<1xi8>
-  // expected-error at +1 {{'tosa.conv2d' op accumulator type for i8 tensor is not i32}}
+  // expected-error at +1 {{'tosa.conv2d' op accumulator type for i8 tensor is not i32, got 'f16'}}
   %0 = tosa.conv2d %arg0, %arg1, %arg2, %zp, %zp {acc_type = f16, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>}
            : (tensor<1x29x29x4xi8>, tensor<16x3x3x4xi8>, tensor<16xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x27x27x16xi8>
   return %0 : tensor<1x27x27x16xi8>
@@ -87,7 +87,7 @@ func.func @test_conv2d_acc_type(%arg0: tensor<1x29x29x4xi8>, %arg1: tensor<16x3x
 func.func @test_conv2d_acc_type(%arg0: tensor<1x29x29x4xi16>, %arg1: tensor<16x3x3x4xi8>, %arg2: tensor<16xi16>) -> tensor<1x27x27x16xi16> {
   %input_zp = "tosa.const"() {values = dense<0> : tensor<1xi16>} : () -> tensor<1xi16>
   %weight_zp = "tosa.const"() {values = dense<0> : tensor<1xi8>} : () -> tensor<1xi8>
-  // expected-error at +1 {{'tosa.conv2d' op accumulator type for i16 tensor is not i48}}
+  // expected-error at +1 {{'tosa.conv2d' op accumulator type for i16 tensor is not i48, got 'f16'}}
   %0 = tosa.conv2d %arg0, %arg1, %arg2, %input_zp, %weight_zp {acc_type = f16, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>}
            : (tensor<1x29x29x4xi16>, tensor<16x3x3x4xi8>, tensor<16xi16>, tensor<1xi16>, tensor<1xi8>) -> tensor<1x27x27x16xi16>
   return %0 : tensor<1x27x27x16xi16>
@@ -97,7 +97,7 @@ func.func @test_conv2d_acc_type(%arg0: tensor<1x29x29x4xi16>, %arg1: tensor<16x3
 
 func.func @test_conv2d_acc_type(%arg0: tensor<1x29x29x4xf8E5M2>, %arg1: tensor<16x3x3x4xf8E5M2>, %arg2: tensor<16xf16>) -> tensor<1x27x27x16xf16> {
   %zp = "tosa.const"() {values = dense<0.0> : tensor<1xf8E5M2>} : () -> tensor<1xf8E5M2>
-  // expected-error at +1 {{'tosa.conv2d' op accumulator type for f8 tensor is not f16}}
+  // expected-error at +1 {{'tosa.conv2d' op accumulator type for f8 tensor is not f16/f32, got 'i32'}}
   %0 = tosa.conv2d %arg0, %arg1, %arg2, %zp, %zp {acc_type = i32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>}
            : (tensor<1x29x29x4xf8E5M2>, tensor<16x3x3x4xf8E5M2>, tensor<16xf16>, tensor<1xf8E5M2>, tensor<1xf8E5M2>) -> tensor<1x27x27x16xf16>
   return %0 : tensor<1x27x27x16xf16>
@@ -107,7 +107,7 @@ func.func @test_conv2d_acc_type(%arg0: tensor<1x29x29x4xf8E5M2>, %arg1: tensor<1
 
 func.func @test_conv2d_acc_type(%arg0: tensor<1x29x29x4xf8E4M3>, %arg1: tensor<16x3x3x4xf8E4M3>, %arg2: tensor<16xf16>) -> tensor<1x27x27x16xf16> {
   %zp = "tosa.const"() {values = dense<0.0> : tensor<1xf8E4M3>} : () -> tensor<1xf8E4M3>
-  // expected-error at +1 {{'tosa.conv2d' op accumulator type for f8 tensor is not f16}}
+  // expected-error at +1 {{'tosa.conv2d' op accumulator type for f8 tensor is not f16/f32, got 'i32'}}
   %0 = tosa.conv2d %arg0, %arg1, %arg2, %zp, %zp {acc_type = i32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>}
            : (tensor<1x29x29x4xf8E4M3>, tensor<16x3x3x4xf8E4M3>, tensor<16xf16>, tensor<1xf8E4M3>, tensor<1xf8E4M3>) -> tensor<1x27x27x16xf16>
   return %0 : tensor<1x27x27x16xf16>
@@ -117,7 +117,7 @@ func.func @test_conv2d_acc_type(%arg0: tensor<1x29x29x4xf8E4M3>, %arg1: tensor<1
 
 func.func @test_conv2d_acc_type(%arg0: tensor<1x29x29x4xf16>, %arg1: tensor<16x3x3x4xf16>, %arg2: tensor<16xf16>) -> tensor<1x27x27x16xf16> {
   %zp = "tosa.const"() {values = dense<0.0> : tensor<1xf16>} : () -> tensor<1xf16>
-  // expected-error at +1 {{'tosa.conv2d' op accumulator type for f16 tensor is not f16/f32}}
+  // expected-error at +1 {{'tosa.conv2d' op accumulator type for f16 tensor is not f16/f32, got 'i32'}}
   %0 = tosa.conv2d %arg0, %arg1, %arg2, %zp, %zp {acc_type = i32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>}
            : (tensor<1x29x29x4xf16>, tensor<16x3x3x4xf16>, tensor<16xf16>, tensor<1xf16>, tensor<1xf16>) -> tensor<1x27x27x16xf16>
   return %0 : tensor<1x27x27x16xf16>
@@ -127,7 +127,7 @@ func.func @test_conv2d_acc_type(%arg0: tensor<1x29x29x4xf16>, %arg1: tensor<16x3
 
 func.func @test_conv2d_acc_type(%arg0: tensor<1x29x29x4xbf16>, %arg1: tensor<16x3x3x4xbf16>, %arg2: tensor<16xbf16>) -> tensor<1x27x27x16xbf16> {
   %zp = "tosa.const"() {values = dense<0.0> : tensor<1xbf16>} : () -> tensor<1xbf16>
-  // expected-error at +1 {{'tosa.conv2d' op accumulator type for bf16 tensor is not f32}}
+  // expected-error at +1 {{'tosa.conv2d' op accumulator type for bf16 tensor is not f32, got 'i32'}}
   %0 = tosa.conv2d %arg0, %arg1, %arg2, %zp, %zp {acc_type = i32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>}
            : (tensor<1x29x29x4xbf16>, tensor<16x3x3x4xbf16>, tensor<16xbf16>, tensor<1xbf16>, tensor<1xbf16>) -> tensor<1x27x27x16xbf16>
   return %0 : tensor<1x27x27x16xbf16>
@@ -137,7 +137,7 @@ func.func @test_conv2d_acc_type(%arg0: tensor<1x29x29x4xbf16>, %arg1: tensor<16x
 
 func.func @test_conv2d_acc_type(%arg0: tensor<1x29x29x4xf32>, %arg1: tensor<16x3x3x4xf32>, %arg2: tensor<16xf32>) -> tensor<1x27x27x16xf32> {
   %zp = "tosa.const"() {values = dense<0.0> : tensor<1xf32>} : () -> tensor<1xf32>
-  // expected-error at +1 {{'tosa.conv2d' op accumulator type for f32 tensor is not f32}}
+  // expected-error at +1 {{'tosa.conv2d' op accumulator type for f32 tensor is not f32, got 'i32'}}
   %0 = tosa.conv2d %arg0, %arg1, %arg2, %zp, %zp {acc_type = i32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>}
            : (tensor<1x29x29x4xf32>, tensor<16x3x3x4xf32>, tensor<16xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x27x27x16xf32>
   return %0 : tensor<1x27x27x16xf32>
@@ -147,7 +147,7 @@ func.func @test_conv2d_acc_type(%arg0: tensor<1x29x29x4xf32>, %arg1: tensor<16x3
 
 func.func @test_conv3d_acc_type(%arg0: tensor<1x4x8x21x17xi8>, %arg1: tensor<34x1x1x1x17xi8>, %arg2: tensor<34xi8>) -> tensor<1x4x8x21x34xi8> {
   %zp = "tosa.const"() {values = dense<0> : tensor<1xi8>} : () -> tensor<1xi8>
-  // expected-error at +1 {{'tosa.conv3d' op accumulator type for i8 tensor is not i32}}
+  // expected-error at +1 {{'tosa.conv3d' op accumulator type for i8 tensor is not i32, got 'f16'}}
   %0 = tosa.conv3d %arg0, %arg1, %arg2, %zp, %zp {acc_type = f16, dilation = array<i64: 1, 1, 1>, pad = array<i64: 0, 0, 0, 0, 0, 0>, stride = array<i64: 1, 1, 1>}
            : (tensor<1x4x8x21x17xi8>, tensor<34x1x1x1x17xi8>, tensor<34xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x4x8x21x34xi8>
   return %0 : tensor<1x4x8x21x34xi8>
@@ -157,7 +157,7 @@ func.func @test_conv3d_acc_type(%arg0: tensor<1x4x8x21x17xi8>, %arg1: tensor<34x
 
 func.func @test_depthwise_conv2d_acc_type(%arg0: tensor<1x4x4x4xi8>, %arg1: tensor<1x1x4x2xi8>, %arg2: tensor<8xi8>) -> tensor<1x4x4x8xi8> {
   %zp = "tosa.const"() {values = dense<0> : tensor<1xi8>} : () -> tensor<1xi8>
-  // expected-error at +1 {{'tosa.depthwise_conv2d' op accumulator type for i8 tensor is not i32}}
+  // expected-error at +1 {{'tosa.depthwise_conv2d' op accumulator type for i8 tensor is not i32, got 'f16'}}
   %0 = tosa.depthwise_conv2d %arg0, %arg1, %arg2, %zp, %zp {acc_type = f16, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<1x4x4x4xi8>, tensor<1x1x4x2xi8>, tensor<8xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x4x4x8xi8>
   return %0 : tensor<1x4x4x8xi8>
 }
@@ -166,7 +166,7 @@ func.func @test_depthwise_conv2d_acc_type(%arg0: tensor<1x4x4x4xi8>, %arg1: tens
 
 func.func @test_transpose_conv2d(%arg0: tensor<1x32x32x8xi8>, %arg1: tensor<16x1x1x8xi8>, %arg2: tensor<16xi8>) -> tensor<1x32x32x16xi8> {
   %zp = "tosa.const"() {values = dense<0> : tensor<1xi8>} : () -> tensor<1xi8>
-  // expected-error at +1 {{'tosa.transpose_conv2d' op accumulator type for i8 tensor is not i32}}
+  // expected-error at +1 {{'tosa.transpose_conv2d' op accumulator type for i8 tensor is not i32, got 'f16'}}
   %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2, %zp, %zp {acc_type = f16, out_pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<1x32x32x8xi8>, tensor<16x1x1x8xi8>, tensor<16xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x32x32x16xi8>
   return %0 : tensor<1x32x32x16xi8>
 }
@@ -247,7 +247,7 @@ func.func @test_transpose_conv2d_invalid_bias(%arg0: tensor<1x32x32x8xf32>, %arg
 // CHECK-LABEL: conv2d_quant_any_acc
 func.func @test_conv2d_quant_any_acc(%arg0: tensor<1x4x4x4x!quant.any<i8<-8:7>>>, %arg1: tensor<8x1x1x4x!quant.any<i8<-8:7>>>, %arg2: tensor<8x!quant.any<i8<-8:7>>>) -> tensor<1x4x4x8x!quant.any<i8<-8:7>>> {
   %zp = "tosa.const" () { values = dense<0> : tensor<1xi8> } : () -> tensor<1xi8>
-  // expected-error at +1 {{'tosa.conv2d' op accumulator type for i8 tensor is not i32}}
+  // expected-error at +1 {{'tosa.conv2d' op accumulator type for i8 tensor is not i32, got 'f32'}}
   %0 = tosa.conv2d %arg0, %arg1, %arg2, %zp, %zp {acc_type = f32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, local_bound = true} : (tensor<1x4x4x4x!quant.any<i8<-8:7>>>, tensor<8x1x1x4x!quant.any<i8<-8:7>>>, tensor<8x!quant.any<i8<-8:7>>>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x4x4x8x!quant.any<i8<-8:7>>>
   return %0 : tensor<1x4x4x8x!quant.any<i8<-8:7>>>
 }
diff --git a/mlir/test/Dialect/Tosa/ops.mlir b/mlir/test/Dialect/Tosa/ops.mlir
index 276eac4d6166d..a6aaebb8b6a10 100644
--- a/mlir/test/Dialect/Tosa/ops.mlir
+++ b/mlir/test/Dialect/Tosa/ops.mlir
@@ -981,6 +981,15 @@ func.func @test_conv2d_f8E5M2(%arg0: tensor<1x4x4x4xf8E5M2>, %arg1: tensor<8x1x1
   return %0 : tensor<1x4x4x8xf16>
 }
 
+// -----
+// CHECK-LABEL: conv2d_f8E5M2_acc32
+func.func @test_conv2d_f8E5M2_acc32(%arg0: tensor<1x4x4x4xf8E5M2>, %arg1: tensor<8x1x1x4xf8E5M2>, %arg2: tensor<8xf16>) -> tensor<1x4x4x8xf16> {
+  %input_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf8E5M2>}> : () -> tensor<1xf8E5M2>
+  %we...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/175517


More information about the Mlir-commits mailing list