[Mlir-commits] [mlir] [mlir][spirv] Add SPV_EXT_FP8 type support to SPIR-V TOSA ops (PR #193199)

Davide Grohmann llvmlistbot at llvm.org
Wed Apr 22 01:11:11 PDT 2026


https://github.com/davidegrohmann updated https://github.com/llvm/llvm-project/pull/193199

>From 3e1665819a086dec5c304d5f749768a38798fb4f Mon Sep 17 00:00:00 2001
From: Davide Grohmann <davide.grohmann at arm.com>
Date: Fri, 17 Apr 2026 13:26:12 +0200
Subject: [PATCH] [mlir][spirv] Add SPV_EXT_FP8 type support to SPIR-V TOSA ops

Add SPV_EXT_FP8 support to the SPIR-V TOSA ops by updating the shared
type definitions to include the FP8 element types (f8E4M3FN and f8E5M2)
and extending the op type constraints to accept them.

Also update verifier coverage to reflect the new constraints:
- refresh existing negative tests whose diagnostics now list FP8 types
- add negative tests for SPV_EXT_FP8-specific output, weight,
  accumulator, and cast restrictions

Signed-off-by: Davide Grohmann <davide.grohmann at arm.com>
Change-Id: Ie636acc87669a66b53410e7efbd3edafa6ee0da1
---
 .../mlir/Dialect/SPIRV/IR/SPIRVTosaOps.td     | 128 ++++----
 .../mlir/Dialect/SPIRV/IR/SPIRVTosaTypes.td   |  31 +-
 .../SPIRV/IR/tosa-ops-verification.mlir       | 284 +++++++++++++++++-
 3 files changed, 373 insertions(+), 70 deletions(-)

diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaOps.td
index c873e3069733c..db91e529dc623 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaOps.td
@@ -156,14 +156,20 @@ class SPIRV_TosaConvolutionOp<string mnemonic, int opcode, list<Trait> traits =
     TypeConstraintImplicationOn<"input", BF16, "output", [BF16]>,
     TypeConstraintImplicationOn<"input", F16, "output", [F16]>,
     TypeConstraintImplicationOn<"input", F32, "output", [F32]>,
+    TypeConstraintImplicationOn<"input", F8E4M3FN, "output", [F16]>,
+    TypeConstraintImplicationOn<"input", F8E5M2, "output", [F16]>,
     TypeConstraintImplicationOn<"input", BF16, "weight", [BF16]>,
     TypeConstraintImplicationOn<"input", F16, "weight", [F16]>,
     TypeConstraintImplicationOn<"input", F32, "weight", [F32]>,
+    TypeConstraintImplicationOn<"input", F8E4M3FN, "weight", [F8E4M3FN]>,
+    TypeConstraintImplicationOn<"input", F8E5M2, "weight", [F8E5M2]>,
     TypeImpliesAccType<"input", I8, ["INT32"]>,
     TypeImpliesAccType<"input", I16, ["INT48"]>,
     TypeImpliesAccType<"input", F16, ["FP16", "FP32"]>,
     TypeImpliesAccType<"input", BF16, ["FP32"]>,
     TypeImpliesAccType<"input", F32, ["FP32"]>,
+    TypeImpliesAccType<"input", F8E4M3FN, ["FP16"]>,
+    TypeImpliesAccType<"input", F8E5M2, ["FP16"]>,
     AllElementTypesMatch<["bias", "output"]>,
     AllElementTypesMatch<["input", "input_zp"]>,
     AllElementTypesMatch<["weight", "weight_zp"]>])> {
@@ -249,7 +255,7 @@ def SPIRV_TosaArgMaxOp : SPIRV_TosaOpWithResult<"ArgMax", 0, [Pure,
   let arguments = (ins
     SPIRV_TensorArmAxisAttr: $axis,
     SPIRV_TosaExtNaNPropagationModeAttr: $nan_mode,
-    SPIRV_I8OrI16OrF16OrF32OrBF16_TensorArm: $input
+    SPIRV_I8OrI16OrF16OrF32OrBF16OrFP8_TensorArm: $input
   );
 
   let results = (outs
@@ -277,6 +283,8 @@ def SPIRV_TosaAvgPool2DOp : SPIRV_TosaOpWithResult<"AvgPool2D", 1, [NoMemoryEffe
   TypeImpliesAccType<"input", F16, ["FP16", "FP32"]>,
   TypeImpliesAccType<"input", BF16, ["FP32"]>,
   TypeImpliesAccType<"input", F32, ["FP32"]>,
+  TypeImpliesAccType<"input", F8E4M3FN, ["FP16"]>,
+  TypeImpliesAccType<"input", F8E5M2, ["FP16"]>,
   AllElementTypesMatch<["input", "input_zp", "output", "output_zp"]>]> {
   let summary = "Performs average pooling on the input.";
 
@@ -304,13 +312,13 @@ def SPIRV_TosaAvgPool2DOp : SPIRV_TosaOpWithResult<"AvgPool2D", 1, [NoMemoryEffe
     SPIRV_I32_1DTensorArmOfLength2Attr: $stride,
     SPIRV_I32_1DTensorArmOfLength4Attr: $pad,
     SPIRV_TosaExtAccTypeAttr: $acc_type,
-    SPIRV_I8OrI16OrF16OrF32OrBF16_TensorArm4D: $input,
-    SPIRV_I8OrI16OrF16OrF32OrBF16_1DTensorArmOfLength1: $input_zp,
-    SPIRV_I8OrI16OrF16OrF32OrBF16_1DTensorArmOfLength1: $output_zp
+    SPIRV_I8OrI16OrF16OrF32OrBF16OrFP8_TensorArm4D: $input,
+    SPIRV_I8OrI16OrF16OrF32OrBF16OrFP8_1DTensorArmOfLength1: $input_zp,
+    SPIRV_I8OrI16OrF16OrF32OrBF16OrFP8_1DTensorArmOfLength1: $output_zp
   );
 
   let results = (outs
-    SPIRV_I8OrI16OrF16OrF32OrBF16_TensorArm4D: $output
+    SPIRV_I8OrI16OrF16OrF32OrBF16OrFP8_TensorArm4D: $output
   );
 
   let assemblyFormat = [{
@@ -361,11 +369,11 @@ def SPIRV_TosaConv2DOp : SPIRV_TosaConvolutionOp<"Conv2D", 2> {
     SPIRV_I32_1DTensorArmOfLength2Attr: $dilation,
     SPIRV_TosaExtAccTypeAttr: $acc_type,
     SPIRV_BoolConstAttr: $local_bound,
-    SPIRV_I8OrI16OrF16OrF32OrBF16_TensorArm4D: $input,
-    SPIRV_I8OrF16OrF32OrBF16_TensorArm4D: $weight,
+    SPIRV_I8OrI16OrF16OrF32OrBF16OrFP8_TensorArm4D: $input,
+    SPIRV_I8OrF16OrF32OrBF16OrFP8_TensorArm4D: $weight,
     SPIRV_I32OrI64OrF16OrF32OrBF16_TensorArm1D: $bias,
-    SPIRV_I8OrI16OrF16OrF32OrBF16_1DTensorArmOfLength1: $input_zp,
-    SPIRV_I8OrF16OrF32OrBF16_1DTensorArmOfLength1: $weight_zp
+    SPIRV_I8OrI16OrF16OrF32OrBF16OrFP8_1DTensorArmOfLength1: $input_zp,
+    SPIRV_I8OrF16OrF32OrBF16OrFP8_1DTensorArmOfLength1: $weight_zp
   );
 
   let results = (outs
@@ -416,11 +424,11 @@ def SPIRV_TosaConv3DOp : SPIRV_TosaConvolutionOp<"Conv3D", 3> {
     SPIRV_I32_1DTensorArmOfLength3Attr: $dilation,
     SPIRV_TosaExtAccTypeAttr: $acc_type,
     SPIRV_BoolConstAttr: $local_bound,
-    SPIRV_I8OrI16OrF16OrF32OrBF16_TensorArm5D: $input,
-    SPIRV_I8OrF16OrF32OrBF16_TensorArm5D: $weight,
+    SPIRV_I8OrI16OrF16OrF32OrBF16OrFP8_TensorArm5D: $input,
+    SPIRV_I8OrF16OrF32OrBF16OrFP8_TensorArm5D: $weight,
     SPIRV_I32OrI64OrF16OrF32OrBF16_TensorArm1D: $bias,
-    SPIRV_I8OrI16OrF16OrF32OrBF16_1DTensorArmOfLength1: $input_zp,
-    SPIRV_I8OrF16OrF32OrBF16_1DTensorArmOfLength1: $weight_zp
+    SPIRV_I8OrI16OrF16OrF32OrBF16OrFP8_1DTensorArmOfLength1: $input_zp,
+    SPIRV_I8OrF16OrF32OrBF16OrFP8_1DTensorArmOfLength1: $weight_zp
   );
 
   let results = (outs
@@ -472,11 +480,11 @@ def SPIRV_TosaDepthwiseConv2DOp : SPIRV_TosaConvolutionOp<"DepthwiseConv2D", 4>
     SPIRV_I32_1DTensorArmOfLength2Attr: $dilation,
     SPIRV_TosaExtAccTypeAttr: $acc_type,
     SPIRV_BoolConstAttr: $local_bound,
-    SPIRV_I8OrI16OrF16OrF32OrBF16_TensorArm4D: $input,
-    SPIRV_I8OrF16OrF32OrBF16_TensorArm4D: $weight,
+    SPIRV_I8OrI16OrF16OrF32OrBF16OrFP8_TensorArm4D: $input,
+    SPIRV_I8OrF16OrF32OrBF16OrFP8_TensorArm4D: $weight,
     SPIRV_I32OrI64OrF16OrF32OrBF16_TensorArm1D: $bias,
-    SPIRV_I8OrI16OrF16OrF32OrBF16_1DTensorArmOfLength1: $input_zp,
-    SPIRV_I8OrF16OrF32OrBF16_1DTensorArmOfLength1: $weight_zp
+    SPIRV_I8OrI16OrF16OrF32OrBF16OrFP8_1DTensorArmOfLength1: $input_zp,
+    SPIRV_I8OrF16OrF32OrBF16OrFP8_1DTensorArmOfLength1: $weight_zp
   );
 
   let results = (outs
@@ -557,6 +565,8 @@ def SPIRV_TosaMatMulOp : SPIRV_TosaOpWithResult<"MatMul", 6, [NoMemoryEffect,
   TypeConstraintImplicationOn<"A", BF16, "output", [F32]>,
   TypeConstraintImplicationOn<"A", F16, "output", [F16, F32]>,
   TypeConstraintImplicationOn<"A", F32, "output", [F32]>,
+  TypeConstraintImplicationOn<"A", F8E4M3FN, "output", [F16]>,
+  TypeConstraintImplicationOn<"A", F8E5M2, "output", [F16]>,
   AllElementTypesMatch<["A", "A_zp", "B", "B_zp"]>]> {
   let summary = "Matrix Multiplication operator.";
 
@@ -579,10 +589,10 @@ def SPIRV_TosaMatMulOp : SPIRV_TosaOpWithResult<"MatMul", 6, [NoMemoryEffect,
   }];
 
   let arguments = (ins
-    SPIRV_I8OrI16OrF16OrF32OrBF16_TensorArm3D: $A,
-    SPIRV_I8OrI16OrF16OrF32OrBF16_TensorArm3D: $B,
-    SPIRV_I8OrI16OrF16OrF32OrBF16_1DTensorArmOfLength1: $A_zp,
-    SPIRV_I8OrI16OrF16OrF32OrBF16_1DTensorArmOfLength1: $B_zp
+    SPIRV_I8OrI16OrF16OrF32OrBF16OrFP8_TensorArm3D: $A,
+    SPIRV_I8OrI16OrF16OrF32OrBF16OrFP8_TensorArm3D: $B,
+    SPIRV_I8OrI16OrF16OrF32OrBF16OrFP8_1DTensorArmOfLength1: $A_zp,
+    SPIRV_I8OrI16OrF16OrF32OrBF16OrFP8_1DTensorArmOfLength1: $B_zp
   );
 
   let results = (outs
@@ -634,11 +644,11 @@ def SPIRV_TosaMaxPool2DOp : SPIRV_TosaOpWithResult<"MaxPool2D", 7, [Pure,
     SPIRV_I32_1DTensorArmOfLength2Attr: $stride,
     SPIRV_I32_1DTensorArmOfLength4Attr: $pad,
     SPIRV_TosaExtNaNPropagationModeAttr: $nan_mode,
-    SPIRV_I8OrI16OrF16OrF32OrBF16_TensorArm4D: $input
+    SPIRV_I8OrI16OrF16OrF32OrBF16OrFP8_TensorArm4D: $input
   );
 
   let results = (outs
-    SPIRV_I8OrI16OrF16OrF32OrBF16_TensorArm4D: $output
+    SPIRV_I8OrI16OrF16OrF32OrBF16OrFP8_TensorArm4D: $output
   );
 
   let assemblyFormat = [{
@@ -734,11 +744,11 @@ def SPIRV_TosaTransposeConv2DOp : SPIRV_TosaConvolutionOp<"TransposeConv2D", 9>
     SPIRV_I32_1DTensorArmOfLength2Attr: $stride,
     SPIRV_TosaExtAccTypeAttr: $acc_type,
     SPIRV_BoolConstAttr: $local_bound,
-    SPIRV_I8OrI16OrF16OrF32OrBF16_TensorArm4D: $input,
-    SPIRV_I8OrF16OrF32OrBF16_TensorArm4D: $weight,
+    SPIRV_I8OrI16OrF16OrF32OrBF16OrFP8_TensorArm4D: $input,
+    SPIRV_I8OrF16OrF32OrBF16OrFP8_TensorArm4D: $weight,
     SPIRV_I32OrI64OrF16OrF32OrBF16_TensorArm1D: $bias,
-    SPIRV_I8OrI16OrF16OrF32OrBF16_1DTensorArmOfLength1: $input_zp,
-    SPIRV_I8OrF16OrF32OrBF16_1DTensorArmOfLength1: $weight_zp
+    SPIRV_I8OrI16OrF16OrF32OrBF16OrFP8_1DTensorArmOfLength1: $input_zp,
+    SPIRV_I8OrF16OrF32OrBF16OrFP8_1DTensorArmOfLength1: $weight_zp
   );
 
   let results = (outs
@@ -2167,11 +2177,11 @@ def SPIRV_TosaConcatOp : SPIRV_TosaOpWithResult<"Concat", 54, [Pure,
 
   let arguments = (ins
     SPIRV_TensorArmAxisAttr: $axis,
-    Variadic<SPIRV_BoolOrI8OrI16OrI32OrF16OrF32OrBF16_TensorArm>: $input1
+    Variadic<SPIRV_BoolOrI8OrI16OrI32OrF16OrF32OrBF16OrFP8_TensorArm>: $input1
   );
 
   let results = (outs
-    SPIRV_BoolOrI8OrI16OrI32OrF16OrF32OrBF16_TensorArm: $output
+    SPIRV_BoolOrI8OrI16OrI32OrF16OrF32OrBF16OrFP8_TensorArm: $output
   );
 
   let assemblyFormat = [{
@@ -2214,13 +2224,13 @@ def SPIRV_TosaPadOp : SPIRV_TosaOpWithResult<"Pad", 55, [Pure,
   }];
 
   let arguments = (ins
-    SPIRV_BoolOrI8OrI16OrI32OrF16OrF32OrBF16_TensorArm: $input1,
+    SPIRV_BoolOrI8OrI16OrI32OrF16OrF32OrBF16OrFP8_TensorArm: $input1,
     SPIRV_I32_1DTensorArmOfEvenLength2To12: $padding,
-    SPIRV_BoolOrI8OrI16OrI32OrF16OrF32OrBF16_1DTensorArmOfLength1: $pad_const
+    SPIRV_BoolOrI8OrI16OrI32OrF16OrF32OrBF16OrFP8_1DTensorArmOfLength1: $pad_const
   );
 
   let results = (outs
-    SPIRV_BoolOrI8OrI16OrI32OrF16OrF32OrBF16_TensorArm: $output
+    SPIRV_BoolOrI8OrI16OrI32OrF16OrF32OrBF16OrFP8_TensorArm: $output
   );
 
   let assemblyFormat = [{
@@ -2267,12 +2277,12 @@ def SPIRV_TosaReshapeOp : SPIRV_TosaOpWithResult<"Reshape", 56, [Pure,
   }];
 
   let arguments = (ins
-    SPIRV_BoolOrI8OrI16OrI32OrF16OrF32OrBF16_TensorArm: $input1,
+    SPIRV_BoolOrI8OrI16OrI32OrF16OrF32OrBF16OrFP8_TensorArm: $input1,
     SPIRV_I32_1DTensorArmOfLength1To6: $shape
   );
 
   let results = (outs
-    SPIRV_BoolOrI8OrI16OrI32OrF16OrF32OrBF16_TensorArm: $output
+    SPIRV_BoolOrI8OrI16OrI32OrF16OrF32OrBF16OrFP8_TensorArm: $output
   );
 
   let assemblyFormat = [{
@@ -2315,11 +2325,11 @@ def SPIRV_TosaReverseOp : SPIRV_TosaOpWithResult<"Reverse", 57, [Pure,
 
   let arguments = (ins
     SPIRV_TensorArmAxisAttr: $axis,
-    SPIRV_BoolOrI8OrI16OrI32OrF16OrF32OrBF16_TensorArm: $input1
+    SPIRV_BoolOrI8OrI16OrI32OrF16OrF32OrBF16OrFP8_TensorArm: $input1
   );
 
   let results = (outs
-    SPIRV_BoolOrI8OrI16OrI32OrF16OrF32OrBF16_TensorArm: $output
+    SPIRV_BoolOrI8OrI16OrI32OrF16OrF32OrBF16OrFP8_TensorArm: $output
   );
 
   let assemblyFormat = [{
@@ -2362,13 +2372,13 @@ def SPIRV_TosaSliceOp : SPIRV_TosaOpWithResult<"Slice", 58, [Pure,
   }];
 
   let arguments = (ins
-    SPIRV_BoolOrI8OrI16OrI32OrF16OrF32OrBF16_TensorArm: $input1,
+    SPIRV_BoolOrI8OrI16OrI32OrF16OrF32OrBF16OrFP8_TensorArm: $input1,
     SPIRV_I32_1DTensorArmOfLength1To6: $start,
     SPIRV_I32_1DTensorArmOfLength1To6: $size
   );
 
   let results = (outs
-    SPIRV_BoolOrI8OrI16OrI32OrF16OrF32OrBF16_TensorArm: $output
+    SPIRV_BoolOrI8OrI16OrI32OrF16OrF32OrBF16OrFP8_TensorArm: $output
   );
 
   let assemblyFormat = [{
@@ -2416,12 +2426,12 @@ def SPIRV_TosaTileOp : SPIRV_TosaOpWithResult<"Tile", 59, [Pure,
   }];
 
   let arguments = (ins
-    SPIRV_BoolOrI8OrI16OrI32OrF16OrF32OrBF16_TensorArm: $input1,
+    SPIRV_BoolOrI8OrI16OrI32OrF16OrF32OrBF16OrFP8_TensorArm: $input1,
     SPIRV_I32_1DTensorArmOfLength1To6: $multiples
   );
 
   let results = (outs
-    SPIRV_BoolOrI8OrI16OrI32OrF16OrF32OrBF16_TensorArm: $output
+    SPIRV_BoolOrI8OrI16OrI32OrF16OrF32OrBF16OrFP8_TensorArm: $output
   );
 
   let assemblyFormat = [{
@@ -2466,11 +2476,11 @@ def SPIRV_TosaTransposeOp : SPIRV_TosaOpWithResult<"Transpose", 60, [Pure,
 
   let arguments = (ins
     SPIRV_I32_1DTensorArmOfLength1To6Attr: $perms,
-    SPIRV_BoolOrI8OrI16OrI32OrF16OrF32OrBF16_TensorArm: $input1
+    SPIRV_BoolOrI8OrI16OrI32OrF16OrF32OrBF16OrFP8_TensorArm: $input1
   );
 
   let results = (outs
-    SPIRV_BoolOrI8OrI16OrI32OrF16OrF32OrBF16_TensorArm: $output
+    SPIRV_BoolOrI8OrI16OrI32OrF16OrF32OrBF16OrFP8_TensorArm: $output
   );
 
   let assemblyFormat = [{
@@ -2512,12 +2522,12 @@ def SPIRV_TosaGatherOp : SPIRV_TosaOpWithResult<"Gather", 61, [NoMemoryEffect,
   }];
 
   let arguments = (ins
-    SPIRV_I8OrI16OrI32OrF16OrF32OrBF16_TensorArm3D: $values,
+    SPIRV_I8OrI16OrI32OrF16OrF32OrBF16OrFP8_TensorArm3D: $values,
     SPIRV_I32_TensorArm2D: $indices
   );
 
   let results = (outs
-    SPIRV_I8OrI16OrI32OrF16OrF32OrBF16_TensorArm3D: $output
+    SPIRV_I8OrI16OrI32OrF16OrF32OrBF16OrFP8_TensorArm3D: $output
   );
 
   let assemblyFormat = [{
@@ -2566,13 +2576,13 @@ def SPIRV_TosaScatterOp : SPIRV_TosaOpWithResult<"Scatter", 62, [NoMemoryEffect,
   }];
 
   let arguments = (ins
-    SPIRV_I8OrI16OrI32OrF16OrF32OrBF16_TensorArm3D: $values_in,
+    SPIRV_I8OrI16OrI32OrF16OrF32OrBF16OrFP8_TensorArm3D: $values_in,
     SPIRV_I32_TensorArm2D: $indices,
-    SPIRV_I8OrI16OrI32OrF16OrF32OrBF16_TensorArm3D: $input
+    SPIRV_I8OrI16OrI32OrF16OrF32OrBF16OrFP8_TensorArm3D: $input
   );
 
   let results = (outs
-    SPIRV_I8OrI16OrI32OrF16OrF32OrBF16_TensorArm3D: $values_out
+    SPIRV_I8OrI16OrI32OrF16OrF32OrBF16OrFP8_TensorArm3D: $values_out
   );
 
   let assemblyFormat = [{
@@ -2687,13 +2697,15 @@ def SPIRV_TosaResizeOp : SPIRV_TosaOpWithResult<"Resize", 63, [Pure,
 
 def SPIRV_TosaCastOp : SPIRV_TosaOpWithResult<"Cast", 64, [Pure,
   AllShapesMatch<["input", "output"]>,
-  TypeConstraintImplicationOn<"input", F16, "output", [F32, I16, I32, I8]>,
-  TypeConstraintImplicationOn<"input", F32, "output", [F16, I16, I32, I8, BF16]>,
+  TypeConstraintImplicationOn<"input", F16, "output", [F32, I16, I32, I8, F8E4M3FN, F8E5M2]>,
+  TypeConstraintImplicationOn<"input", F32, "output", [F16, I16, I32, I8, BF16, F8E4M3FN, F8E5M2]>,
   TypeConstraintImplicationOn<"input", I16, "output", [F16, F32, I32, I8, SPIRV_Bool, BF16]>,
   TypeConstraintImplicationOn<"input", I32, "output", [F16, F32, I16, I8, SPIRV_Bool, BF16]>,
   TypeConstraintImplicationOn<"input", I8, "output", [F16, F32, I16, I32, SPIRV_Bool, BF16]>,
   TypeConstraintImplicationOn<"input", SPIRV_Bool, "output", [I16, I32, I8]>,
-  TypeConstraintImplicationOn<"input", BF16, "output", [F32, I16, I32, I8]>]> {
+  TypeConstraintImplicationOn<"input", BF16, "output", [F32, I16, I32, I8, F8E4M3FN, F8E5M2]>,
+  TypeConstraintImplicationOn<"input", F8E4M3FN, "output", [F16, F32, BF16]>,
+  TypeConstraintImplicationOn<"input", F8E5M2, "output", [F16, F32, BF16]>]> {
   let summary = "Cast operation.";
 
   let description = [{
@@ -2737,6 +2749,18 @@ def SPIRV_TosaCastOp : SPIRV_TosaOpWithResult<"Cast", 64, [Pure,
     | int16   | bf16    |
     | int32   | bf16    |
     | int8    | bf16    |
+    | bf16    | fp8e4m3 |
+    | fp8e4m3 | bf16    |
+    | bf16    | fp8e5m2 |
+    | fp8e5m2 | bf16    |
+    | float16 | fp8e4m3 |
+    | float32 | fp8e4m3 |
+    | fp8e4m3 | float16 |
+    | fp8e4m3 | float32 |
+    | float16 | fp8e5m2 |
+    | float32 | fp8e5m2 |
+    | fp8e5m2 | float16 |
+    | fp8e5m2 | float32 |
 
     References:
       * https://github.khronos.org/SPIRV-Registry/extended/TOSA.001000.1.html#_cast
@@ -2750,11 +2774,11 @@ def SPIRV_TosaCastOp : SPIRV_TosaOpWithResult<"Cast", 64, [Pure,
   }];
 
   let arguments = (ins
-    SPIRV_BoolOrI8OrI16OrI32OrBF16OrF16OrF32_TensorArm: $input
+    SPIRV_BoolOrI8OrI16OrI32OrF16OrF32OrBF16OrFP8_TensorArm: $input
   );
 
   let results = (outs
-    SPIRV_I8OrI16OrI32OrBoolOrF16OrF32OrBF16_TensorArm: $output
+    SPIRV_I8OrI16OrI32OrBoolOrF16OrF32OrBF16OrFP8_TensorArm: $output
   );
 
   let assemblyFormat = [{
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaTypes.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaTypes.td
index 6c918aec28845..5704911d7f53d 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaTypes.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaTypes.td
@@ -22,16 +22,19 @@ def SPIRV_I8OrI16OrI32OrI64 : AnyIntOfWidths<[8, 16, 32, 64]>;
 def SPIRV_I16OrI32 : AnyIntOfWidths<[16, 32]>;
 def SPIRV_I32OrI64 : AnyIntOfWidths<[32, 64]>;
 def SPIRV_F16OrF32OrBF16 : AnyTypeOf<[SPIRV_Float16, SPIRV_Float32, SPIRV_BFloat16KHR]>;
-def SPIRV_I8OrF16OrF32OrBF16 : AnyTypeOf<[SPIRV_Int8, SPIRV_F16OrF32OrBF16]>;
+def SPIRV_F16OrF32OrBF16OrFP8 : AnyTypeOf<[SPIRV_Float16, SPIRV_Float32, SPIRV_BFloat16KHR, SPIRV_Float8E4M3EXT, SPIRV_Float8E5M2EXT]>;
 def SPIRV_I8OrI16OrF16OrF32OrBF16 : AnyTypeOf<[SPIRV_I8OrI16, SPIRV_F16OrF32OrBF16]>;
+def SPIRV_I8OrF16OrF32OrBF16OrFP8 : AnyTypeOf<[SPIRV_Int8, SPIRV_F16OrF32OrBF16OrFP8]>;
+def SPIRV_I8OrI16OrF16OrF32OrBF16OrFP8 : AnyTypeOf<[SPIRV_I8OrI16, SPIRV_F16OrF32OrBF16OrFP8]>;
 def SPIRV_I32OrF16OrF32OrBF16 : AnyTypeOf<[SPIRV_Int32, SPIRV_F16OrF32OrBF16]>;
 def SPIRV_I8OrI16OrI32OrF16OrF32OrBF16 : AnyTypeOf<[SPIRV_I8OrI16OrI32, SPIRV_F16OrF32OrBF16]>;
+def SPIRV_I8OrI16OrI32OrF16OrF32OrBF16OrFP8 : AnyTypeOf<[SPIRV_I8OrI16OrI32, SPIRV_F16OrF32OrBF16OrFP8]>;
 def SPIRV_I32OrI64OrF16OrF32OrBF16 : AnyTypeOf<[SPIRV_I32OrI64, SPIRV_F16OrF32OrBF16]>;
 def SPIRV_I32OrI64OrF16OrF32 : AnyTypeOf<[SPIRV_I32OrI64, SPIRV_Float16, SPIRV_Float32]>;
 def SPIRV_I32OrI8OrI64OrI16OrF16OrF32OrBF16 : AnyTypeOf<[SPIRV_I8OrI16OrI32OrI64, SPIRV_F16OrF32OrBF16]>;
-def SPIRV_BoolOrI8OrI16OrI32OrBF16OrF16OrF32 : AnyTypeOf<[SPIRV_Bool, SPIRV_I8OrI16OrI32, SPIRV_F16OrF32OrBF16]>;
 def SPIRV_BoolOrI8OrI16OrI32OrF16OrF32OrBF16 : AnyTypeOf<[SPIRV_Bool, SPIRV_I8OrI16OrI32, SPIRV_F16OrF32OrBF16]>;
-def SPIRV_I8OrI16OrI32OrBoolOrF16OrF32OrBF16 : AnyTypeOf<[SPIRV_I8OrI16OrI32, SPIRV_Bool, SPIRV_F16OrF32OrBF16]>;
+def SPIRV_BoolOrI8OrI16OrI32OrF16OrF32OrBF16OrFP8 : AnyTypeOf<[SPIRV_Bool, SPIRV_I8OrI16OrI32, SPIRV_F16OrF32OrBF16OrFP8]>;
+def SPIRV_I8OrI16OrI32OrBoolOrF16OrF32OrBF16OrFP8 : AnyTypeOf<[SPIRV_I8OrI16OrI32, SPIRV_Bool, SPIRV_F16OrF32OrBF16OrFP8]>;
 def SPIRV_I8OrI32 : AnyTypeOf<[SPIRV_Int8, SPIRV_Int32]>;
 
 def SPIRV_TensorArmAxisAttr : ConfinedAttr<I32Attr, [IntNonNegative, IntMaxValue<5>]>;
@@ -57,23 +60,25 @@ def SPIRV_I32_TensorArm2D : TensorArmRankOf<[SPIRV_Int32], [2]>;
 def SPIRV_F32_TensorArm3D: TensorArmRankOf<[SPIRV_Float32], [3]>;
 def SPIRV_I32OrI64OrF16OrF32OrBF16_TensorArm1D : TensorArmRankOf<[SPIRV_I32OrI64OrF16OrF32OrBF16], [1]>;
 def SPIRV_I8OrI16_TensorArm1D : TensorArmRankOf<[SPIRV_I8OrI16], [1]>;
-def SPIRV_I8OrI16OrF16OrF32OrBF16_TensorArm3D : TensorArmRankOf<[SPIRV_I8OrI16OrF16OrF32OrBF16], [3]>;
-def SPIRV_I8OrI16OrI32OrF16OrF32OrBF16_TensorArm3D : TensorArmRankOf<[SPIRV_I8OrI16OrI32OrF16OrF32OrBF16], [3]>;
+def SPIRV_I8OrI16OrF16OrF32OrBF16OrFP8_TensorArm3D : TensorArmRankOf<[SPIRV_I8OrI16OrF16OrF32OrBF16OrFP8], [3]>;
+def SPIRV_I8OrI16OrI32OrF16OrF32OrBF16OrFP8_TensorArm3D : TensorArmRankOf<[SPIRV_I8OrI16OrI32OrF16OrF32OrBF16OrFP8], [3]>;
 def SPIRV_I32OrI64OrF16OrF32_TensorArm3D : TensorArmRankOf<[SPIRV_I32OrI64OrF16OrF32], [3]>;
 def SPIRV_I8OrI16OrF16OrF32OrBF16_TensorArm4D : TensorArmRankOf<[SPIRV_I8OrI16OrF16OrF32OrBF16], [4]>;
-def SPIRV_I8OrF16OrF32OrBF16_TensorArm4D : TensorArmRankOf<[SPIRV_I8OrF16OrF32OrBF16], [4]>;
+def SPIRV_I8OrI16OrF16OrF32OrBF16OrFP8_TensorArm4D : TensorArmRankOf<[SPIRV_I8OrI16OrF16OrF32OrBF16OrFP8], [4]>;
+def SPIRV_I8OrF16OrF32OrBF16OrFP8_TensorArm4D : TensorArmRankOf<[SPIRV_I8OrF16OrF32OrBF16OrFP8], [4]>;
 def SPIRV_I32OrI64OrF16OrF32OrBF16_TensorArm4D : TensorArmRankOf<[SPIRV_I32OrI64OrF16OrF32OrBF16], [4]>;
 def SPIRV_I32OrI8OrI64OrI16OrF16OrF32OrBF16_TensorArm4D : TensorArmRankOf<[SPIRV_I32OrI8OrI64OrI16OrF16OrF32OrBF16], [4]>;
-def SPIRV_I8OrI16OrF16OrF32OrBF16_TensorArm5D : TensorArmRankOf<[SPIRV_I8OrI16OrF16OrF32OrBF16], [5]>;
-def SPIRV_I8OrF16OrF32OrBF16_TensorArm5D : TensorArmRankOf<[SPIRV_I8OrF16OrF32OrBF16], [5]>;
+def SPIRV_I8OrI16OrF16OrF32OrBF16OrFP8_TensorArm5D : TensorArmRankOf<[SPIRV_I8OrI16OrF16OrF32OrBF16OrFP8], [5]>;
+def SPIRV_I8OrF16OrF32OrBF16OrFP8_TensorArm5D : TensorArmRankOf<[SPIRV_I8OrF16OrF32OrBF16OrFP8], [5]>;
 def SPIRV_I32OrI64OrF16OrF32OrBF16_TensorArm5D : TensorArmRankOf<[SPIRV_I32OrI64OrF16OrF32OrBF16], [5]>;
 
-def SPIRV_BoolOrI8OrI16OrI32OrBF16OrF16OrF32_TensorArm : TensorArmRankOf<[SPIRV_BoolOrI8OrI16OrI32OrBF16OrF16OrF32], [1, 2, 3, 4, 5, 6]>;
 def SPIRV_BoolOrI8OrI16OrI32OrF16OrF32OrBF16_TensorArm : TensorArmRankOf<[SPIRV_BoolOrI8OrI16OrI32OrF16OrF32OrBF16], [1, 2, 3, 4, 5, 6]>;
+def SPIRV_BoolOrI8OrI16OrI32OrF16OrF32OrBF16OrFP8_TensorArm : TensorArmRankOf<[SPIRV_BoolOrI8OrI16OrI32OrF16OrF32OrBF16OrFP8], [1, 2, 3, 4, 5, 6]>;
 def SPIRV_F16OrF32OrBF16_TensorArm : TensorArmRankOf<[SPIRV_F16OrF32OrBF16], [1, 2, 3, 4, 5, 6]>;
 def SPIRV_I8OrI16OrF16OrF32OrBF16_TensorArm : TensorArmRankOf<[SPIRV_I8OrI16OrF16OrF32OrBF16], [1, 2, 3, 4, 5, 6]>;
+def SPIRV_I8OrI16OrF16OrF32OrBF16OrFP8_TensorArm : TensorArmRankOf<[SPIRV_I8OrI16OrF16OrF32OrBF16OrFP8], [1, 2, 3, 4, 5, 6]>;
 def SPIRV_I32OrF16OrF32OrBF16_TensorArm : TensorArmRankOf<[SPIRV_I32OrF16OrF32OrBF16], [1, 2, 3, 4, 5, 6]>;
-def SPIRV_I8OrI16OrI32OrBoolOrF16OrF32OrBF16_TensorArm : TensorArmRankOf<[SPIRV_I8OrI16OrI32OrBoolOrF16OrF32OrBF16], [1, 2, 3, 4, 5, 6]>;
+def SPIRV_I8OrI16OrI32OrBoolOrF16OrF32OrBF16OrFP8_TensorArm : TensorArmRankOf<[SPIRV_I8OrI16OrI32OrBoolOrF16OrF32OrBF16OrFP8], [1, 2, 3, 4, 5, 6]>;
 def SPIRV_I8OrI16OrI32OrF16OrF32OrBF16_TensorArm : TensorArmRankOf<[SPIRV_I8OrI16OrI32OrF16OrF32OrBF16], [1, 2, 3, 4, 5, 6]>;
 def SPIRV_I8OrI16OrI32OrI64_TensorArm : TensorArmRankOf<[SPIRV_I8OrI16OrI32OrI64], [1, 2, 3, 4, 5, 6]>;
 def SPIRV_I8OrI16OrI32_TensorArm : TensorArmRankOf<[SPIRV_I8OrI16OrI32], [1, 2, 3, 4, 5, 6]>;
@@ -121,12 +126,12 @@ def SPIRV_I32_1DTensorArmOfLength1To6Attr : ConfinedAttr<
   I32ElementsAttr, [SPIRV_DenseElementAttrsWithTensorArmType, Is1DTensorArmAttrOfLength<[1, 2, 3, 4, 5, 6]>]>;
 
 def SPIRV_I8_1DTensorArmOfLength1 : SPIRV_1DTensorArmOfLengthAndType<[1], [SPIRV_Int8]>;
-def SPIRV_I8OrI16OrF16OrF32OrBF16_1DTensorArmOfLength1 : SPIRV_1DTensorArmOfLengthAndType<[1], [SPIRV_I8OrI16OrF16OrF32OrBF16]>;
+def SPIRV_I8OrI16OrF16OrF32OrBF16OrFP8_1DTensorArmOfLength1 : SPIRV_1DTensorArmOfLengthAndType<[1], [SPIRV_I8OrI16OrF16OrF32OrBF16OrFP8]>;
 def SPIRV_I8OrI16OrI32OrF16OrF32OrBF16_1DTensorArmOfLength1 : SPIRV_1DTensorArmOfLengthAndType<[1], [SPIRV_I8OrI16OrI32OrF16OrF32OrBF16]>;
 def SPIRV_I8OrI16OrI32OrI64_1DTensorArmOfLength1 : SPIRV_1DTensorArmOfLengthAndType<[1], [SPIRV_I8OrI16OrI32OrI64]>;
 def SPIRV_I8OrI16OrI32_1DTensorArmOfLength1 : SPIRV_1DTensorArmOfLengthAndType<[1], [SPIRV_I8OrI16OrI32]>;
-def SPIRV_I8OrF16OrF32OrBF16_1DTensorArmOfLength1 : SPIRV_1DTensorArmOfLengthAndType<[1], [SPIRV_I8OrF16OrF32OrBF16]>;
-def SPIRV_BoolOrI8OrI16OrI32OrF16OrF32OrBF16_1DTensorArmOfLength1 : SPIRV_1DTensorArmOfLengthAndType<[1], [SPIRV_BoolOrI8OrI16OrI32OrF16OrF32OrBF16]>;
+def SPIRV_I8OrF16OrF32OrBF16OrFP8_1DTensorArmOfLength1 : SPIRV_1DTensorArmOfLengthAndType<[1], [SPIRV_I8OrF16OrF32OrBF16OrFP8]>;
+def SPIRV_BoolOrI8OrI16OrI32OrF16OrF32OrBF16OrFP8_1DTensorArmOfLength1 : SPIRV_1DTensorArmOfLengthAndType<[1], [SPIRV_BoolOrI8OrI16OrI32OrF16OrF32OrBF16OrFP8]>;
 
 // Struct type
 
diff --git a/mlir/test/Dialect/SPIRV/IR/tosa-ops-verification.mlir b/mlir/test/Dialect/SPIRV/IR/tosa-ops-verification.mlir
index c238de30ff0e6..78ae3b4586004 100644
--- a/mlir/test/Dialect/SPIRV/IR/tosa-ops-verification.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/tosa-ops-verification.mlir
@@ -75,6 +75,22 @@ spirv.ARM.Graph @avgpool2d_accumulator_should_be_either_FP32_for_fp32_element_ty
   spirv.ARM.GraphOutputs %6 : !spirv.arm.tensor<1x2x65532x2xf32>
 }
 
+spirv.ARM.Graph @avgpool2d_accumulator_must_be_FP16_for_f8e4m3fn_element_type(%arg0: !spirv.arm.tensor<1x2x2x2xf8E4M3FN>) -> (!spirv.arm.tensor<1x2x2x2xf8E4M3FN>) {
+  %4 = spirv.Constant dense<0.000000e+00> : !spirv.arm.tensor<1xf8E4M3FN>
+  %5 = spirv.Constant dense<0.000000e+00> : !spirv.arm.tensor<1xf8E4M3FN>
+  // expected-error @+1 {{op failed to verify that acc_type must be one in [FP16] when type has value f8E4M3FN type}}
+  %6 = spirv.Tosa.AvgPool2D kernel = [1, 1], stride = [1, 1], pad = [0, 0, 0, 0], acc_type = <FP32>, %arg0, %4, %5 : !spirv.arm.tensor<1x2x2x2xf8E4M3FN>, !spirv.arm.tensor<1xf8E4M3FN>, !spirv.arm.tensor<1xf8E4M3FN> -> !spirv.arm.tensor<1x2x2x2xf8E4M3FN>
+  spirv.ARM.GraphOutputs %6 : !spirv.arm.tensor<1x2x2x2xf8E4M3FN>
+}
+
+spirv.ARM.Graph @avgpool2d_accumulator_must_be_FP16_for_f8e5m2_element_type(%arg0: !spirv.arm.tensor<1x2x2x2xf8E5M2>) -> (!spirv.arm.tensor<1x2x2x2xf8E5M2>) {
+  %4 = spirv.Constant dense<0.000000e+00> : !spirv.arm.tensor<1xf8E5M2>
+  %5 = spirv.Constant dense<0.000000e+00> : !spirv.arm.tensor<1xf8E5M2>
+  // expected-error @+1 {{op failed to verify that acc_type must be one in [FP16] when type has value f8E5M2 type}}
+  %6 = spirv.Tosa.AvgPool2D kernel = [1, 1], stride = [1, 1], pad = [0, 0, 0, 0], acc_type = <FP32>, %arg0, %4, %5 : !spirv.arm.tensor<1x2x2x2xf8E5M2>, !spirv.arm.tensor<1xf8E5M2>, !spirv.arm.tensor<1xf8E5M2> -> !spirv.arm.tensor<1x2x2x2xf8E5M2>
+  spirv.ARM.GraphOutputs %6 : !spirv.arm.tensor<1x2x2x2xf8E5M2>
+}
+
 //===----------------------------------------------------------------------===//
 // spirv.TOSA.Conv2D
 //===----------------------------------------------------------------------===//
@@ -151,6 +167,54 @@ spirv.ARM.Graph @conv2d_accumulator_must_be_either_FP32_for_f32_input_element_ty
   spirv.ARM.GraphOutputs %7 : !spirv.arm.tensor<1x34x18x11xf32>
 }
 
+spirv.ARM.Graph @conv2d_mismatch_result_element_type_f8e5m2_input(%arg0: !spirv.arm.tensor<1x4x4x4xf8E5M2>, %arg1: !spirv.arm.tensor<8x1x1x4xf8E5M2>, %arg2: !spirv.arm.tensor<8xf32>) -> (!spirv.arm.tensor<1x4x4x8xf32>) {
+  %5 = spirv.Constant dense<0.000000e+00> : !spirv.arm.tensor<1xf8E5M2>
+  %6 = spirv.Constant dense<0.000000e+00> : !spirv.arm.tensor<1xf8E5M2>
+  // expected-error @+1 {{op failed to verify that if input has type f8E5M2 type then output must have a type in [16-bit float]}}
+  %7 = spirv.Tosa.Conv2D pad = [0, 0, 0, 0], stride = [1, 1], dilation = [1, 1], acc_type = <FP16>, local_bound = false, %arg0, %arg1, %arg2, %5, %6 : !spirv.arm.tensor<1x4x4x4xf8E5M2>, !spirv.arm.tensor<8x1x1x4xf8E5M2>, !spirv.arm.tensor<8xf32>, !spirv.arm.tensor<1xf8E5M2>, !spirv.arm.tensor<1xf8E5M2> -> !spirv.arm.tensor<1x4x4x8xf32>
+  spirv.ARM.GraphOutputs %7 : !spirv.arm.tensor<1x4x4x8xf32>
+}
+
+spirv.ARM.Graph @conv2d_mismatch_result_element_type_f8e4m3fn_input(%arg0: !spirv.arm.tensor<1x4x4x4xf8E4M3FN>, %arg1: !spirv.arm.tensor<8x1x1x4xf8E4M3FN>, %arg2: !spirv.arm.tensor<8xf32>) -> (!spirv.arm.tensor<1x4x4x8xf32>) {
+  %5 = spirv.Constant dense<0.000000e+00> : !spirv.arm.tensor<1xf8E4M3FN>
+  %6 = spirv.Constant dense<0.000000e+00> : !spirv.arm.tensor<1xf8E4M3FN>
+  // expected-error @+1 {{op failed to verify that if input has type f8E4M3FN type then output must have a type in [16-bit float]}}
+  %7 = spirv.Tosa.Conv2D pad = [0, 0, 0, 0], stride = [1, 1], dilation = [1, 1], acc_type = <FP16>, local_bound = false, %arg0, %arg1, %arg2, %5, %6 : !spirv.arm.tensor<1x4x4x4xf8E4M3FN>, !spirv.arm.tensor<8x1x1x4xf8E4M3FN>, !spirv.arm.tensor<8xf32>, !spirv.arm.tensor<1xf8E4M3FN>, !spirv.arm.tensor<1xf8E4M3FN> -> !spirv.arm.tensor<1x4x4x8xf32>
+  spirv.ARM.GraphOutputs %7 : !spirv.arm.tensor<1x4x4x8xf32>
+}
+
+spirv.ARM.Graph @conv2d_weight_element_type_must_match_f8e4m3fn_input(%arg0: !spirv.arm.tensor<1x4x4x4xf8E4M3FN>, %arg1: !spirv.arm.tensor<8x1x1x4xf8E5M2>, %arg2: !spirv.arm.tensor<8xf16>) -> (!spirv.arm.tensor<1x4x4x8xf16>) {
+  %5 = spirv.Constant dense<0.000000e+00> : !spirv.arm.tensor<1xf8E4M3FN>
+  %6 = spirv.Constant dense<0.000000e+00> : !spirv.arm.tensor<1xf8E5M2>
+  // expected-error @+1 {{op failed to verify that if input has type f8E4M3FN type then weight must have a type in [f8E4M3FN type]}}
+  %7 = spirv.Tosa.Conv2D pad = [0, 0, 0, 0], stride = [1, 1], dilation = [1, 1], acc_type = <FP16>, local_bound = false, %arg0, %arg1, %arg2, %5, %6 : !spirv.arm.tensor<1x4x4x4xf8E4M3FN>, !spirv.arm.tensor<8x1x1x4xf8E5M2>, !spirv.arm.tensor<8xf16>, !spirv.arm.tensor<1xf8E4M3FN>, !spirv.arm.tensor<1xf8E5M2> -> !spirv.arm.tensor<1x4x4x8xf16>
+  spirv.ARM.GraphOutputs %7 : !spirv.arm.tensor<1x4x4x8xf16>
+}
+
+spirv.ARM.Graph @conv2d_weight_element_type_must_match_f8e5m2_input(%arg0: !spirv.arm.tensor<1x4x4x4xf8E5M2>, %arg1: !spirv.arm.tensor<8x1x1x4xf8E4M3FN>, %arg2: !spirv.arm.tensor<8xf16>) -> (!spirv.arm.tensor<1x4x4x8xf16>) {
+  %5 = spirv.Constant dense<0.000000e+00> : !spirv.arm.tensor<1xf8E5M2>
+  %6 = spirv.Constant dense<0.000000e+00> : !spirv.arm.tensor<1xf8E4M3FN>
+  // expected-error @+1 {{op failed to verify that if input has type f8E5M2 type then weight must have a type in [f8E5M2 type]}}
+  %7 = spirv.Tosa.Conv2D pad = [0, 0, 0, 0], stride = [1, 1], dilation = [1, 1], acc_type = <FP16>, local_bound = false, %arg0, %arg1, %arg2, %5, %6 : !spirv.arm.tensor<1x4x4x4xf8E5M2>, !spirv.arm.tensor<8x1x1x4xf8E4M3FN>, !spirv.arm.tensor<8xf16>, !spirv.arm.tensor<1xf8E5M2>, !spirv.arm.tensor<1xf8E4M3FN> -> !spirv.arm.tensor<1x4x4x8xf16>
+  spirv.ARM.GraphOutputs %7 : !spirv.arm.tensor<1x4x4x8xf16>
+}
+
+spirv.ARM.Graph @conv2d_accumulator_must_be_FP16_for_f8e4m3fn_input_element_type(%arg0: !spirv.arm.tensor<1x4x4x4xf8E4M3FN>, %arg1: !spirv.arm.tensor<8x1x1x4xf8E4M3FN>, %arg2: !spirv.arm.tensor<8xf16>) -> (!spirv.arm.tensor<1x4x4x8xf16>) {
+  %5 = spirv.Constant dense<0.000000e+00> : !spirv.arm.tensor<1xf8E4M3FN>
+  %6 = spirv.Constant dense<0.000000e+00> : !spirv.arm.tensor<1xf8E4M3FN>
+  // expected-error @+1 {{op failed to verify that acc_type must be one in [FP16] when type has value f8E4M3FN type}}
+  %7 = spirv.Tosa.Conv2D pad = [0, 0, 0, 0], stride = [1, 1], dilation = [1, 1], acc_type = <FP32>, local_bound = false, %arg0, %arg1, %arg2, %5, %6 : !spirv.arm.tensor<1x4x4x4xf8E4M3FN>, !spirv.arm.tensor<8x1x1x4xf8E4M3FN>, !spirv.arm.tensor<8xf16>, !spirv.arm.tensor<1xf8E4M3FN>, !spirv.arm.tensor<1xf8E4M3FN> -> !spirv.arm.tensor<1x4x4x8xf16>
+  spirv.ARM.GraphOutputs %7 : !spirv.arm.tensor<1x4x4x8xf16>
+}
+
+spirv.ARM.Graph @conv2d_accumulator_must_be_FP16_for_f8e5m2_input_element_type(%arg0: !spirv.arm.tensor<1x4x4x4xf8E5M2>, %arg1: !spirv.arm.tensor<8x1x1x4xf8E5M2>, %arg2: !spirv.arm.tensor<8xf16>) -> (!spirv.arm.tensor<1x4x4x8xf16>) {
+  %5 = spirv.Constant dense<0.000000e+00> : !spirv.arm.tensor<1xf8E5M2>
+  %6 = spirv.Constant dense<0.000000e+00> : !spirv.arm.tensor<1xf8E5M2>
+  // expected-error @+1 {{op failed to verify that acc_type must be one in [FP16] when type has value f8E5M2 type}}
+  %7 = spirv.Tosa.Conv2D pad = [0, 0, 0, 0], stride = [1, 1], dilation = [1, 1], acc_type = <FP32>, local_bound = false, %arg0, %arg1, %arg2, %5, %6 : !spirv.arm.tensor<1x4x4x4xf8E5M2>, !spirv.arm.tensor<8x1x1x4xf8E5M2>, !spirv.arm.tensor<8xf16>, !spirv.arm.tensor<1xf8E5M2>, !spirv.arm.tensor<1xf8E5M2> -> !spirv.arm.tensor<1x4x4x8xf16>
+  spirv.ARM.GraphOutputs %7 : !spirv.arm.tensor<1x4x4x8xf16>
+}
+
 //===----------------------------------------------------------------------===//
 // spirv.TOSA.Conv3D
 //===----------------------------------------------------------------------===//
@@ -227,6 +291,54 @@ spirv.ARM.Graph @conv3d_accumulator_must_be_either_FP32_for_f32_input_element_ty
   spirv.ARM.GraphOutputs %7 : !spirv.arm.tensor<1x34x18x11x1xf32>
 }
 
+spirv.ARM.Graph @conv3d_weight_element_type_must_match_f8e4m3fn_input(%arg0: !spirv.arm.tensor<1x4x4x4x4xf8E4M3FN>, %arg1: !spirv.arm.tensor<8x1x1x1x4xf8E5M2>, %arg2: !spirv.arm.tensor<8xf16>) -> (!spirv.arm.tensor<1x4x4x4x8xf16>) {
+  %input_zp = spirv.Constant dense<0.000000e+00> : !spirv.arm.tensor<1xf8E4M3FN>
+  %weight_zp = spirv.Constant dense<0.000000e+00> : !spirv.arm.tensor<1xf8E5M2>
+  // expected-error @+1 {{op failed to verify that if input has type f8E4M3FN type then weight must have a type in [f8E4M3FN type]}}
+  %output = spirv.Tosa.Conv3D pad = [0, 0, 0, 0, 0, 0], stride = [1, 1, 1], dilation = [1, 1, 1], acc_type = <FP16>, local_bound = false, %arg0, %arg1, %arg2, %input_zp, %weight_zp : !spirv.arm.tensor<1x4x4x4x4xf8E4M3FN>, !spirv.arm.tensor<8x1x1x1x4xf8E5M2>, !spirv.arm.tensor<8xf16>, !spirv.arm.tensor<1xf8E4M3FN>, !spirv.arm.tensor<1xf8E5M2> -> !spirv.arm.tensor<1x4x4x4x8xf16>
+  spirv.ARM.GraphOutputs %output : !spirv.arm.tensor<1x4x4x4x8xf16>
+}
+
+spirv.ARM.Graph @conv3d_mismatch_result_element_type_f8e4m3fn_input(%arg0: !spirv.arm.tensor<1x4x4x4x4xf8E4M3FN>, %arg1: !spirv.arm.tensor<8x1x1x1x4xf8E4M3FN>, %arg2: !spirv.arm.tensor<8xf32>) -> (!spirv.arm.tensor<1x4x4x4x8xf32>) {
+  %input_zp = spirv.Constant dense<0.000000e+00> : !spirv.arm.tensor<1xf8E4M3FN>
+  %weight_zp = spirv.Constant dense<0.000000e+00> : !spirv.arm.tensor<1xf8E4M3FN>
+  // expected-error @+1 {{op failed to verify that if input has type f8E4M3FN type then output must have a type in [16-bit float]}}
+  %output = spirv.Tosa.Conv3D pad = [0, 0, 0, 0, 0, 0], stride = [1, 1, 1], dilation = [1, 1, 1], acc_type = <FP16>, local_bound = false, %arg0, %arg1, %arg2, %input_zp, %weight_zp : !spirv.arm.tensor<1x4x4x4x4xf8E4M3FN>, !spirv.arm.tensor<8x1x1x1x4xf8E4M3FN>, !spirv.arm.tensor<8xf32>, !spirv.arm.tensor<1xf8E4M3FN>, !spirv.arm.tensor<1xf8E4M3FN> -> !spirv.arm.tensor<1x4x4x4x8xf32>
+  spirv.ARM.GraphOutputs %output : !spirv.arm.tensor<1x4x4x4x8xf32>
+}
+
+spirv.ARM.Graph @conv3d_mismatch_result_element_type_f8e5m2_input(%arg0: !spirv.arm.tensor<1x4x4x4x4xf8E5M2>, %arg1: !spirv.arm.tensor<8x1x1x1x4xf8E5M2>, %arg2: !spirv.arm.tensor<8xf32>) -> (!spirv.arm.tensor<1x4x4x4x8xf32>) {
+  %input_zp = spirv.Constant dense<0.000000e+00> : !spirv.arm.tensor<1xf8E5M2>
+  %weight_zp = spirv.Constant dense<0.000000e+00> : !spirv.arm.tensor<1xf8E5M2>
+  // expected-error @+1 {{op failed to verify that if input has type f8E5M2 type then output must have a type in [16-bit float]}}
+  %output = spirv.Tosa.Conv3D pad = [0, 0, 0, 0, 0, 0], stride = [1, 1, 1], dilation = [1, 1, 1], acc_type = <FP16>, local_bound = false, %arg0, %arg1, %arg2, %input_zp, %weight_zp : !spirv.arm.tensor<1x4x4x4x4xf8E5M2>, !spirv.arm.tensor<8x1x1x1x4xf8E5M2>, !spirv.arm.tensor<8xf32>, !spirv.arm.tensor<1xf8E5M2>, !spirv.arm.tensor<1xf8E5M2> -> !spirv.arm.tensor<1x4x4x4x8xf32>
+  spirv.ARM.GraphOutputs %output : !spirv.arm.tensor<1x4x4x4x8xf32>
+}
+
+spirv.ARM.Graph @conv3d_weight_element_type_must_match_f8e5m2_input(%arg0: !spirv.arm.tensor<1x4x4x4x4xf8E5M2>, %arg1: !spirv.arm.tensor<8x1x1x1x4xf8E4M3FN>, %arg2: !spirv.arm.tensor<8xf16>) -> (!spirv.arm.tensor<1x4x4x4x8xf16>) {
+  %input_zp = spirv.Constant dense<0.000000e+00> : !spirv.arm.tensor<1xf8E5M2>
+  %weight_zp = spirv.Constant dense<0.000000e+00> : !spirv.arm.tensor<1xf8E4M3FN>
+  // expected-error @+1 {{op failed to verify that if input has type f8E5M2 type then weight must have a type in [f8E5M2 type]}}
+  %output = spirv.Tosa.Conv3D pad = [0, 0, 0, 0, 0, 0], stride = [1, 1, 1], dilation = [1, 1, 1], acc_type = <FP16>, local_bound = false, %arg0, %arg1, %arg2, %input_zp, %weight_zp : !spirv.arm.tensor<1x4x4x4x4xf8E5M2>, !spirv.arm.tensor<8x1x1x1x4xf8E4M3FN>, !spirv.arm.tensor<8xf16>, !spirv.arm.tensor<1xf8E5M2>, !spirv.arm.tensor<1xf8E4M3FN> -> !spirv.arm.tensor<1x4x4x4x8xf16>
+  spirv.ARM.GraphOutputs %output : !spirv.arm.tensor<1x4x4x4x8xf16>
+}
+
+spirv.ARM.Graph @conv3d_accumulator_must_be_FP16_for_f8e4m3fn_input_element_type(%arg0: !spirv.arm.tensor<1x4x4x4x4xf8E4M3FN>, %arg1: !spirv.arm.tensor<8x1x1x1x4xf8E4M3FN>, %arg2: !spirv.arm.tensor<8xf16>) -> (!spirv.arm.tensor<1x4x4x4x8xf16>) {
+  %input_zp = spirv.Constant dense<0.000000e+00> : !spirv.arm.tensor<1xf8E4M3FN>
+  %weight_zp = spirv.Constant dense<0.000000e+00> : !spirv.arm.tensor<1xf8E4M3FN>
+  // expected-error @+1 {{op failed to verify that acc_type must be one in [FP16] when type has value f8E4M3FN type}}
+  %output = spirv.Tosa.Conv3D pad = [0, 0, 0, 0, 0, 0], stride = [1, 1, 1], dilation = [1, 1, 1], acc_type = <FP32>, local_bound = false, %arg0, %arg1, %arg2, %input_zp, %weight_zp : !spirv.arm.tensor<1x4x4x4x4xf8E4M3FN>, !spirv.arm.tensor<8x1x1x1x4xf8E4M3FN>, !spirv.arm.tensor<8xf16>, !spirv.arm.tensor<1xf8E4M3FN>, !spirv.arm.tensor<1xf8E4M3FN> -> !spirv.arm.tensor<1x4x4x4x8xf16>
+  spirv.ARM.GraphOutputs %output : !spirv.arm.tensor<1x4x4x4x8xf16>
+}
+
+spirv.ARM.Graph @conv3d_accumulator_must_be_FP16_for_f8e5m2_input_element_type(%arg0: !spirv.arm.tensor<1x4x4x4x4xf8E5M2>, %arg1: !spirv.arm.tensor<8x1x1x1x4xf8E5M2>, %arg2: !spirv.arm.tensor<8xf16>) -> (!spirv.arm.tensor<1x4x4x4x8xf16>) {
+  %input_zp = spirv.Constant dense<0.000000e+00> : !spirv.arm.tensor<1xf8E5M2>
+  %weight_zp = spirv.Constant dense<0.000000e+00> : !spirv.arm.tensor<1xf8E5M2>
+  // expected-error @+1 {{op failed to verify that acc_type must be one in [FP16] when type has value f8E5M2 type}}
+  %output = spirv.Tosa.Conv3D pad = [0, 0, 0, 0, 0, 0], stride = [1, 1, 1], dilation = [1, 1, 1], acc_type = <FP32>, local_bound = false, %arg0, %arg1, %arg2, %input_zp, %weight_zp : !spirv.arm.tensor<1x4x4x4x4xf8E5M2>, !spirv.arm.tensor<8x1x1x1x4xf8E5M2>, !spirv.arm.tensor<8xf16>, !spirv.arm.tensor<1xf8E5M2>, !spirv.arm.tensor<1xf8E5M2> -> !spirv.arm.tensor<1x4x4x4x8xf16>
+  spirv.ARM.GraphOutputs %output : !spirv.arm.tensor<1x4x4x4x8xf16>
+}
+
 //===----------------------------------------------------------------------===//
 // spirv.TOSA.DepthwiseConv2D
 //===----------------------------------------------------------------------===//
@@ -303,6 +415,54 @@ spirv.ARM.Graph @depthwise_conv2d_accumulator_must_be_either_FP32_for_f32_input_
   spirv.ARM.GraphOutputs %7 : !spirv.arm.tensor<1x34x18x11xf32>
 }
 
+spirv.ARM.Graph @depthwise_conv2d_accumulator_must_be_FP16_for_f8e5m2_input_element_type(%arg0: !spirv.arm.tensor<1x4x4x4xf8E5M2>, %arg1: !spirv.arm.tensor<1x1x4x2xf8E5M2>, %arg2: !spirv.arm.tensor<8xf16>) -> (!spirv.arm.tensor<1x4x4x8xf16>) {
+  %input_zp = spirv.Constant dense<0.000000e+00> : !spirv.arm.tensor<1xf8E5M2>
+  %weight_zp = spirv.Constant dense<0.000000e+00> : !spirv.arm.tensor<1xf8E5M2>
+  // expected-error @+1 {{op failed to verify that acc_type must be one in [FP16] when type has value f8E5M2 type}}
+  %output = spirv.Tosa.DepthwiseConv2D pad = [0, 0, 0, 0], stride = [1, 1], dilation = [1, 1], acc_type = <FP32>, local_bound = false, %arg0, %arg1, %arg2, %input_zp, %weight_zp : !spirv.arm.tensor<1x4x4x4xf8E5M2>, !spirv.arm.tensor<1x1x4x2xf8E5M2>, !spirv.arm.tensor<8xf16>, !spirv.arm.tensor<1xf8E5M2>, !spirv.arm.tensor<1xf8E5M2> -> !spirv.arm.tensor<1x4x4x8xf16>
+  spirv.ARM.GraphOutputs %output : !spirv.arm.tensor<1x4x4x8xf16>
+}
+
+spirv.ARM.Graph @depthwise_conv2d_mismatch_result_element_type_f8e4m3fn_input(%arg0: !spirv.arm.tensor<1x4x4x4xf8E4M3FN>, %arg1: !spirv.arm.tensor<1x1x4x2xf8E4M3FN>, %arg2: !spirv.arm.tensor<8xf32>) -> (!spirv.arm.tensor<1x4x4x8xf32>) {
+  %input_zp = spirv.Constant dense<0.000000e+00> : !spirv.arm.tensor<1xf8E4M3FN>
+  %weight_zp = spirv.Constant dense<0.000000e+00> : !spirv.arm.tensor<1xf8E4M3FN>
+  // expected-error @+1 {{op failed to verify that if input has type f8E4M3FN type then output must have a type in [16-bit float]}}
+  %output = spirv.Tosa.DepthwiseConv2D pad = [0, 0, 0, 0], stride = [1, 1], dilation = [1, 1], acc_type = <FP16>, local_bound = false, %arg0, %arg1, %arg2, %input_zp, %weight_zp : !spirv.arm.tensor<1x4x4x4xf8E4M3FN>, !spirv.arm.tensor<1x1x4x2xf8E4M3FN>, !spirv.arm.tensor<8xf32>, !spirv.arm.tensor<1xf8E4M3FN>, !spirv.arm.tensor<1xf8E4M3FN> -> !spirv.arm.tensor<1x4x4x8xf32>
+  spirv.ARM.GraphOutputs %output : !spirv.arm.tensor<1x4x4x8xf32>
+}
+
+spirv.ARM.Graph @depthwise_conv2d_mismatch_result_element_type_f8e5m2_input(%arg0: !spirv.arm.tensor<1x4x4x4xf8E5M2>, %arg1: !spirv.arm.tensor<1x1x4x2xf8E5M2>, %arg2: !spirv.arm.tensor<8xf32>) -> (!spirv.arm.tensor<1x4x4x8xf32>) {
+  %input_zp = spirv.Constant dense<0.000000e+00> : !spirv.arm.tensor<1xf8E5M2>
+  %weight_zp = spirv.Constant dense<0.000000e+00> : !spirv.arm.tensor<1xf8E5M2>
+  // expected-error @+1 {{op failed to verify that if input has type f8E5M2 type then output must have a type in [16-bit float]}}
+  %output = spirv.Tosa.DepthwiseConv2D pad = [0, 0, 0, 0], stride = [1, 1], dilation = [1, 1], acc_type = <FP16>, local_bound = false, %arg0, %arg1, %arg2, %input_zp, %weight_zp : !spirv.arm.tensor<1x4x4x4xf8E5M2>, !spirv.arm.tensor<1x1x4x2xf8E5M2>, !spirv.arm.tensor<8xf32>, !spirv.arm.tensor<1xf8E5M2>, !spirv.arm.tensor<1xf8E5M2> -> !spirv.arm.tensor<1x4x4x8xf32>
+  spirv.ARM.GraphOutputs %output : !spirv.arm.tensor<1x4x4x8xf32>
+}
+
+spirv.ARM.Graph @depthwise_conv2d_weight_element_type_must_match_f8e4m3fn_input(%arg0: !spirv.arm.tensor<1x4x4x4xf8E4M3FN>, %arg1: !spirv.arm.tensor<1x1x4x2xf8E5M2>, %arg2: !spirv.arm.tensor<8xf16>) -> (!spirv.arm.tensor<1x4x4x8xf16>) {
+  %input_zp = spirv.Constant dense<0.000000e+00> : !spirv.arm.tensor<1xf8E4M3FN>
+  %weight_zp = spirv.Constant dense<0.000000e+00> : !spirv.arm.tensor<1xf8E5M2>
+  // expected-error @+1 {{op failed to verify that if input has type f8E4M3FN type then weight must have a type in [f8E4M3FN type]}}
+  %output = spirv.Tosa.DepthwiseConv2D pad = [0, 0, 0, 0], stride = [1, 1], dilation = [1, 1], acc_type = <FP16>, local_bound = false, %arg0, %arg1, %arg2, %input_zp, %weight_zp : !spirv.arm.tensor<1x4x4x4xf8E4M3FN>, !spirv.arm.tensor<1x1x4x2xf8E5M2>, !spirv.arm.tensor<8xf16>, !spirv.arm.tensor<1xf8E4M3FN>, !spirv.arm.tensor<1xf8E5M2> -> !spirv.arm.tensor<1x4x4x8xf16>
+  spirv.ARM.GraphOutputs %output : !spirv.arm.tensor<1x4x4x8xf16>
+}
+
+spirv.ARM.Graph @depthwise_conv2d_weight_element_type_must_match_f8e5m2_input(%arg0: !spirv.arm.tensor<1x4x4x4xf8E5M2>, %arg1: !spirv.arm.tensor<1x1x4x2xf8E4M3FN>, %arg2: !spirv.arm.tensor<8xf16>) -> (!spirv.arm.tensor<1x4x4x8xf16>) {
+  %input_zp = spirv.Constant dense<0.000000e+00> : !spirv.arm.tensor<1xf8E5M2>
+  %weight_zp = spirv.Constant dense<0.000000e+00> : !spirv.arm.tensor<1xf8E4M3FN>
+  // expected-error @+1 {{op failed to verify that if input has type f8E5M2 type then weight must have a type in [f8E5M2 type]}}
+  %output = spirv.Tosa.DepthwiseConv2D pad = [0, 0, 0, 0], stride = [1, 1], dilation = [1, 1], acc_type = <FP16>, local_bound = false, %arg0, %arg1, %arg2, %input_zp, %weight_zp : !spirv.arm.tensor<1x4x4x4xf8E5M2>, !spirv.arm.tensor<1x1x4x2xf8E4M3FN>, !spirv.arm.tensor<8xf16>, !spirv.arm.tensor<1xf8E5M2>, !spirv.arm.tensor<1xf8E4M3FN> -> !spirv.arm.tensor<1x4x4x8xf16>
+  spirv.ARM.GraphOutputs %output : !spirv.arm.tensor<1x4x4x8xf16>
+}
+
+spirv.ARM.Graph @depthwise_conv2d_accumulator_must_be_FP16_for_f8e4m3fn_input_element_type(%arg0: !spirv.arm.tensor<1x4x4x4xf8E4M3FN>, %arg1: !spirv.arm.tensor<1x1x4x2xf8E4M3FN>, %arg2: !spirv.arm.tensor<8xf16>) -> (!spirv.arm.tensor<1x4x4x8xf16>) {
+  %input_zp = spirv.Constant dense<0.000000e+00> : !spirv.arm.tensor<1xf8E4M3FN>
+  %weight_zp = spirv.Constant dense<0.000000e+00> : !spirv.arm.tensor<1xf8E4M3FN>
+  // expected-error @+1 {{op failed to verify that acc_type must be one in [FP16] when type has value f8E4M3FN type}}
+  %output = spirv.Tosa.DepthwiseConv2D pad = [0, 0, 0, 0], stride = [1, 1], dilation = [1, 1], acc_type = <FP32>, local_bound = false, %arg0, %arg1, %arg2, %input_zp, %weight_zp : !spirv.arm.tensor<1x4x4x4xf8E4M3FN>, !spirv.arm.tensor<1x1x4x2xf8E4M3FN>, !spirv.arm.tensor<8xf16>, !spirv.arm.tensor<1xf8E4M3FN>, !spirv.arm.tensor<1xf8E4M3FN> -> !spirv.arm.tensor<1x4x4x8xf16>
+  spirv.ARM.GraphOutputs %output : !spirv.arm.tensor<1x4x4x8xf16>
+}
+
 //===----------------------------------------------------------------------===//
 // spirv.TOSA.MatMul
 //===----------------------------------------------------------------------===//
@@ -355,6 +515,18 @@ spirv.ARM.Graph @matmul_element_types_must_match_between_input_B_and_B_zero_poin
   spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<1x4x4xi32>
 }
 
+spirv.ARM.Graph @matmul_mismatch_result_element_type_f8e4m3fn_input(%arg0: !spirv.arm.tensor<1x4x4xf8E4M3FN>, %arg1: !spirv.arm.tensor<1x4x4xf8E4M3FN>, %arg2: !spirv.arm.tensor<1xf8E4M3FN>, %arg3: !spirv.arm.tensor<1xf8E4M3FN>) -> (!spirv.arm.tensor<1x4x4xf32>) {
+  // expected-error @+1 {{op failed to verify that if A has type f8E4M3FN type then output must have a type in [16-bit float]}}
+  %output = spirv.Tosa.MatMul %arg0, %arg1, %arg2, %arg3 : !spirv.arm.tensor<1x4x4xf8E4M3FN>, !spirv.arm.tensor<1x4x4xf8E4M3FN>, !spirv.arm.tensor<1xf8E4M3FN>, !spirv.arm.tensor<1xf8E4M3FN> -> !spirv.arm.tensor<1x4x4xf32>
+  spirv.ARM.GraphOutputs %output : !spirv.arm.tensor<1x4x4xf32>
+}
+
+spirv.ARM.Graph @matmul_mismatch_result_element_type_f8e5m2_input(%arg0: !spirv.arm.tensor<1x4x4xf8E5M2>, %arg1: !spirv.arm.tensor<1x4x4xf8E5M2>, %arg2: !spirv.arm.tensor<1xf8E5M2>, %arg3: !spirv.arm.tensor<1xf8E5M2>) -> (!spirv.arm.tensor<1x4x4xf32>) {
+  // expected-error @+1 {{op failed to verify that if A has type f8E5M2 type then output must have a type in [16-bit float]}}
+  %output = spirv.Tosa.MatMul %arg0, %arg1, %arg2, %arg3 : !spirv.arm.tensor<1x4x4xf8E5M2>, !spirv.arm.tensor<1x4x4xf8E5M2>, !spirv.arm.tensor<1xf8E5M2>, !spirv.arm.tensor<1xf8E5M2> -> !spirv.arm.tensor<1x4x4xf32>
+  spirv.ARM.GraphOutputs %output : !spirv.arm.tensor<1x4x4xf32>
+}
+
 //===----------------------------------------------------------------------===//
 // spirv.TOSA.MaxPool2D
 //===----------------------------------------------------------------------===//
@@ -441,6 +613,54 @@ spirv.ARM.Graph @transpose_conv2d_accumulator_must_be_either_FP32_for_f32_input_
   spirv.ARM.GraphOutputs %7 : !spirv.arm.tensor<1x34x18x11xf32>
 }
 
+spirv.ARM.Graph @transpose_conv2d_mismatch_result_element_type_f8e4m3fn_input(%arg0: !spirv.arm.tensor<1x4x4x4xf8E4M3FN>, %arg1: !spirv.arm.tensor<8x1x1x4xf8E4M3FN>, %arg2: !spirv.arm.tensor<8xbf16>) -> (!spirv.arm.tensor<1x4x4x8xbf16>) {
+  %input_zp = spirv.Constant dense<0.000000e+00> : !spirv.arm.tensor<1xf8E4M3FN>
+  %weight_zp = spirv.Constant dense<0.000000e+00> : !spirv.arm.tensor<1xf8E4M3FN>
+  // expected-error @+1 {{op failed to verify that if input has type f8E4M3FN type then output must have a type in [16-bit float]}}
+  %output = spirv.Tosa.TransposeConv2D out_pad = [0, 0, 0, 0], stride = [1, 1], acc_type = <FP16>, local_bound = false, %arg0, %arg1, %arg2, %input_zp, %weight_zp : !spirv.arm.tensor<1x4x4x4xf8E4M3FN>, !spirv.arm.tensor<8x1x1x4xf8E4M3FN>, !spirv.arm.tensor<8xbf16>, !spirv.arm.tensor<1xf8E4M3FN>, !spirv.arm.tensor<1xf8E4M3FN> -> !spirv.arm.tensor<1x4x4x8xbf16>
+  spirv.ARM.GraphOutputs %output : !spirv.arm.tensor<1x4x4x8xbf16>
+}
+
+spirv.ARM.Graph @transpose_conv2d_mismatch_result_element_type_f8e5m2_input(%arg0: !spirv.arm.tensor<1x4x4x4xf8E5M2>, %arg1: !spirv.arm.tensor<8x1x1x4xf8E5M2>, %arg2: !spirv.arm.tensor<8xbf16>) -> (!spirv.arm.tensor<1x4x4x8xbf16>) {
+  %input_zp = spirv.Constant dense<0.000000e+00> : !spirv.arm.tensor<1xf8E5M2>
+  %weight_zp = spirv.Constant dense<0.000000e+00> : !spirv.arm.tensor<1xf8E5M2>
+  // expected-error @+1 {{op failed to verify that if input has type f8E5M2 type then output must have a type in [16-bit float]}}
+  %output = spirv.Tosa.TransposeConv2D out_pad = [0, 0, 0, 0], stride = [1, 1], acc_type = <FP16>, local_bound = false, %arg0, %arg1, %arg2, %input_zp, %weight_zp : !spirv.arm.tensor<1x4x4x4xf8E5M2>, !spirv.arm.tensor<8x1x1x4xf8E5M2>, !spirv.arm.tensor<8xbf16>, !spirv.arm.tensor<1xf8E5M2>, !spirv.arm.tensor<1xf8E5M2> -> !spirv.arm.tensor<1x4x4x8xbf16>
+  spirv.ARM.GraphOutputs %output : !spirv.arm.tensor<1x4x4x8xbf16>
+}
+
+spirv.ARM.Graph @transpose_conv2d_weight_element_type_must_match_f8e4m3fn_input(%arg0: !spirv.arm.tensor<1x4x4x4xf8E4M3FN>, %arg1: !spirv.arm.tensor<8x1x1x4xf8E5M2>, %arg2: !spirv.arm.tensor<8xf16>) -> (!spirv.arm.tensor<1x4x4x8xf16>) {
+  %input_zp = spirv.Constant dense<0.000000e+00> : !spirv.arm.tensor<1xf8E4M3FN>
+  %weight_zp = spirv.Constant dense<0.000000e+00> : !spirv.arm.tensor<1xf8E5M2>
+  // expected-error @+1 {{op failed to verify that if input has type f8E4M3FN type then weight must have a type in [f8E4M3FN type]}}
+  %output = spirv.Tosa.TransposeConv2D out_pad = [0, 0, 0, 0], stride = [1, 1], acc_type = <FP16>, local_bound = false, %arg0, %arg1, %arg2, %input_zp, %weight_zp : !spirv.arm.tensor<1x4x4x4xf8E4M3FN>, !spirv.arm.tensor<8x1x1x4xf8E5M2>, !spirv.arm.tensor<8xf16>, !spirv.arm.tensor<1xf8E4M3FN>, !spirv.arm.tensor<1xf8E5M2> -> !spirv.arm.tensor<1x4x4x8xf16>
+  spirv.ARM.GraphOutputs %output : !spirv.arm.tensor<1x4x4x8xf16>
+}
+
+spirv.ARM.Graph @transpose_conv2d_weight_element_type_must_match_f8e5m2_input(%arg0: !spirv.arm.tensor<1x4x4x4xf8E5M2>, %arg1: !spirv.arm.tensor<8x1x1x4xf8E4M3FN>, %arg2: !spirv.arm.tensor<8xf16>) -> (!spirv.arm.tensor<1x4x4x8xf16>) {
+  %input_zp = spirv.Constant dense<0.000000e+00> : !spirv.arm.tensor<1xf8E5M2>
+  %weight_zp = spirv.Constant dense<0.000000e+00> : !spirv.arm.tensor<1xf8E4M3FN>
+  // expected-error @+1 {{op failed to verify that if input has type f8E5M2 type then weight must have a type in [f8E5M2 type]}}
+  %output = spirv.Tosa.TransposeConv2D out_pad = [0, 0, 0, 0], stride = [1, 1], acc_type = <FP16>, local_bound = false, %arg0, %arg1, %arg2, %input_zp, %weight_zp : !spirv.arm.tensor<1x4x4x4xf8E5M2>, !spirv.arm.tensor<8x1x1x4xf8E4M3FN>, !spirv.arm.tensor<8xf16>, !spirv.arm.tensor<1xf8E5M2>, !spirv.arm.tensor<1xf8E4M3FN> -> !spirv.arm.tensor<1x4x4x8xf16>
+  spirv.ARM.GraphOutputs %output : !spirv.arm.tensor<1x4x4x8xf16>
+}
+
+spirv.ARM.Graph @transpose_conv2d_accumulator_must_be_FP16_for_f8e4m3fn_input_element_type(%arg0: !spirv.arm.tensor<1x4x4x4xf8E4M3FN>, %arg1: !spirv.arm.tensor<8x1x1x4xf8E4M3FN>, %arg2: !spirv.arm.tensor<8xf16>) -> (!spirv.arm.tensor<1x4x4x8xf16>) {
+  %input_zp = spirv.Constant dense<0.000000e+00> : !spirv.arm.tensor<1xf8E4M3FN>
+  %weight_zp = spirv.Constant dense<0.000000e+00> : !spirv.arm.tensor<1xf8E4M3FN>
+  // expected-error @+1 {{op failed to verify that acc_type must be one in [FP16] when type has value f8E4M3FN type}}
+  %output = spirv.Tosa.TransposeConv2D out_pad = [0, 0, 0, 0], stride = [1, 1], acc_type = <FP32>, local_bound = false, %arg0, %arg1, %arg2, %input_zp, %weight_zp : !spirv.arm.tensor<1x4x4x4xf8E4M3FN>, !spirv.arm.tensor<8x1x1x4xf8E4M3FN>, !spirv.arm.tensor<8xf16>, !spirv.arm.tensor<1xf8E4M3FN>, !spirv.arm.tensor<1xf8E4M3FN> -> !spirv.arm.tensor<1x4x4x8xf16>
+  spirv.ARM.GraphOutputs %output : !spirv.arm.tensor<1x4x4x8xf16>
+}
+
+spirv.ARM.Graph @transpose_conv2d_accumulator_must_be_FP16_for_f8e5m2_input_element_type(%arg0: !spirv.arm.tensor<1x4x4x4xf8E5M2>, %arg1: !spirv.arm.tensor<8x1x1x4xf8E5M2>, %arg2: !spirv.arm.tensor<8xf16>) -> (!spirv.arm.tensor<1x4x4x8xf16>) {
+  %input_zp = spirv.Constant dense<0.000000e+00> : !spirv.arm.tensor<1xf8E5M2>
+  %weight_zp = spirv.Constant dense<0.000000e+00> : !spirv.arm.tensor<1xf8E5M2>
+  // expected-error @+1 {{op failed to verify that acc_type must be one in [FP16] when type has value f8E5M2 type}}
+  %output = spirv.Tosa.TransposeConv2D out_pad = [0, 0, 0, 0], stride = [1, 1], acc_type = <FP32>, local_bound = false, %arg0, %arg1, %arg2, %input_zp, %weight_zp : !spirv.arm.tensor<1x4x4x4xf8E5M2>, !spirv.arm.tensor<8x1x1x4xf8E5M2>, !spirv.arm.tensor<8xf16>, !spirv.arm.tensor<1xf8E5M2>, !spirv.arm.tensor<1xf8E5M2> -> !spirv.arm.tensor<1x4x4x8xf16>
+  spirv.ARM.GraphOutputs %output : !spirv.arm.tensor<1x4x4x8xf16>
+}
+
 //===----------------------------------------------------------------------===//
 // spirv.TOSA.Clamp
 //===----------------------------------------------------------------------===//
@@ -1866,23 +2086,35 @@ spirv.ARM.Graph @cast_input_output_shapes_not_matching(%arg0: !spirv.arm.tensor<
 }
 
 spirv.ARM.Graph @cast_f16_to_bf16_not_supported(%arg0: !spirv.arm.tensor<2x3x4xf16>) -> (!spirv.arm.tensor<2x3x4xbf16>) {
-  // expected-error @+1 {{op failed to verify that if input has type 16-bit float then output must have a type in [32-bit float,16-bit signless integer,32-bit signless integer,8-bit signless integer]}}
+  // expected-error @+1 {{op failed to verify that if input has type 16-bit float then output must have a type in [32-bit float,16-bit signless integer,32-bit signless integer,8-bit signless integer,f8E4M3FN type,f8E5M2 type]}}
   %0 = spirv.Tosa.Cast %arg0 : !spirv.arm.tensor<2x3x4xf16> -> !spirv.arm.tensor<2x3x4xbf16>
   spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<2x3x4xbf16>
 }
 
 spirv.ARM.Graph @cast_f16_to_f16_not_supported(%arg0: !spirv.arm.tensor<2x3x4xf16>) -> (!spirv.arm.tensor<2x3x4xf16>) {
-  // expected-error @+1 {{op failed to verify that if input has type 16-bit float then output must have a type in [32-bit float,16-bit signless integer,32-bit signless integer,8-bit signless integer]}}
+  // expected-error @+1 {{op failed to verify that if input has type 16-bit float then output must have a type in [32-bit float,16-bit signless integer,32-bit signless integer,8-bit signless integer,f8E4M3FN type,f8E5M2 type]}}
   %0 = spirv.Tosa.Cast %arg0 : !spirv.arm.tensor<2x3x4xf16> -> !spirv.arm.tensor<2x3x4xf16>
   spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<2x3x4xf16>
 }
 
+spirv.ARM.Graph @cast_f16_to_bool_not_supported(%arg0: !spirv.arm.tensor<2x3x4xf16>) -> (!spirv.arm.tensor<2x3x4xi1>) {
+  // expected-error @+1 {{op failed to verify that if input has type 16-bit float then output must have a type in [32-bit float,16-bit signless integer,32-bit signless integer,8-bit signless integer,f8E4M3FN type,f8E5M2 type]}}
+  %output = spirv.Tosa.Cast %arg0 : !spirv.arm.tensor<2x3x4xf16> -> !spirv.arm.tensor<2x3x4xi1>
+  spirv.ARM.GraphOutputs %output : !spirv.arm.tensor<2x3x4xi1>
+}
+
 spirv.ARM.Graph @cast_f32_to_f32_not_supported(%arg0: !spirv.arm.tensor<2x3x4xf32>) -> (!spirv.arm.tensor<2x3x4xf32>) {
-  // expected-error @+1 {{op failed to verify that if input has type 32-bit float then output must have a type in [16-bit float,16-bit signless integer,32-bit signless integer,8-bit signless integer,bfloat16 type]}}
+  // expected-error @+1 {{op failed to verify that if input has type 32-bit float then output must have a type in [16-bit float,16-bit signless integer,32-bit signless integer,8-bit signless integer,bfloat16 type,f8E4M3FN type,f8E5M2 type]}}
   %0 = spirv.Tosa.Cast %arg0 : !spirv.arm.tensor<2x3x4xf32> -> !spirv.arm.tensor<2x3x4xf32>
   spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<2x3x4xf32>
 }
 
+spirv.ARM.Graph @cast_f32_to_bool_not_supported(%arg0: !spirv.arm.tensor<2x3x4xf32>) -> (!spirv.arm.tensor<2x3x4xi1>) {
+  // expected-error @+1 {{op failed to verify that if input has type 32-bit float then output must have a type in [16-bit float,16-bit signless integer,32-bit signless integer,8-bit signless integer,bfloat16 type,f8E4M3FN type,f8E5M2 type]}}
+  %output = spirv.Tosa.Cast %arg0 : !spirv.arm.tensor<2x3x4xf32> -> !spirv.arm.tensor<2x3x4xi1>
+  spirv.ARM.GraphOutputs %output : !spirv.arm.tensor<2x3x4xi1>
+}
+
 spirv.ARM.Graph @cast_i8_to_i8_not_supported(%arg0: !spirv.arm.tensor<2x3x4xi8>) -> (!spirv.arm.tensor<2x3x4xi8>) {
   // expected-error @+1 {{op failed to verify that if input has type 8-bit signless integer then output must have a type in [16-bit float,32-bit float,16-bit signless integer,32-bit signless integer,bool,bfloat16 type]}}
   %0 = spirv.Tosa.Cast %arg0 : !spirv.arm.tensor<2x3x4xi8> -> !spirv.arm.tensor<2x3x4xi8>
@@ -1920,17 +2152,59 @@ spirv.ARM.Graph @cast_bool_to_bf16_not_supported(%arg0: !spirv.arm.tensor<2x3x4x
 }
 
 spirv.ARM.Graph @cast_bf16_to_f16_not_supported(%arg0: !spirv.arm.tensor<2x3x4xbf16>) -> (!spirv.arm.tensor<2x3x4xf16>) {
-  // expected-error @+1 {{op failed to verify that if input has type bfloat16 type then output must have a type in [32-bit float,16-bit signless integer,32-bit signless integer,8-bit signless integer]}}
+  // expected-error @+1 {{op failed to verify that if input has type bfloat16 type then output must have a type in [32-bit float,16-bit signless integer,32-bit signless integer,8-bit signless integer,f8E4M3FN type,f8E5M2 type]}}
   %0 = spirv.Tosa.Cast %arg0 : !spirv.arm.tensor<2x3x4xbf16> -> !spirv.arm.tensor<2x3x4xf16>
   spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<2x3x4xf16>
 }
 
 spirv.ARM.Graph @cast_bf16_to_bf16_not_supported(%arg0: !spirv.arm.tensor<2x3x4xbf16>) -> (!spirv.arm.tensor<2x3x4xbf16>) {
-  // expected-error @+1 {{op failed to verify that if input has type bfloat16 type then output must have a type in [32-bit float,16-bit signless integer,32-bit signless integer,8-bit signless integer]}}
+  // expected-error @+1 {{op failed to verify that if input has type bfloat16 type then output must have a type in [32-bit float,16-bit signless integer,32-bit signless integer,8-bit signless integer,f8E4M3FN type,f8E5M2 type]}}
   %0 = spirv.Tosa.Cast %arg0 : !spirv.arm.tensor<2x3x4xbf16> -> !spirv.arm.tensor<2x3x4xbf16>
   spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<2x3x4xbf16>
 }
 
+spirv.ARM.Graph @cast_bf16_to_bool_not_supported(%arg0: !spirv.arm.tensor<2x3x4xbf16>) -> (!spirv.arm.tensor<2x3x4xi1>) {
+  // expected-error @+1 {{op failed to verify that if input has type bfloat16 type then output must have a type in [32-bit float,16-bit signless integer,32-bit signless integer,8-bit signless integer,f8E4M3FN type,f8E5M2 type]}}
+  %output = spirv.Tosa.Cast %arg0 : !spirv.arm.tensor<2x3x4xbf16> -> !spirv.arm.tensor<2x3x4xi1>
+  spirv.ARM.GraphOutputs %output : !spirv.arm.tensor<2x3x4xi1>
+}
+
+spirv.ARM.Graph @cast_f8e4m3fn_to_i8_not_supported(%arg0: !spirv.arm.tensor<2x3x4xf8E4M3FN>) -> (!spirv.arm.tensor<2x3x4xi8>) {
+  // expected-error @+1 {{op failed to verify that if input has type f8E4M3FN type then output must have a type in [16-bit float,32-bit float,bfloat16 type]}}
+  %output = spirv.Tosa.Cast %arg0 : !spirv.arm.tensor<2x3x4xf8E4M3FN> -> !spirv.arm.tensor<2x3x4xi8>
+  spirv.ARM.GraphOutputs %output : !spirv.arm.tensor<2x3x4xi8>
+}
+
+spirv.ARM.Graph @cast_f8e4m3fn_to_bool_not_supported(%arg0: !spirv.arm.tensor<2x3x4xf8E4M3FN>) -> (!spirv.arm.tensor<2x3x4xi1>) {
+  // expected-error @+1 {{op failed to verify that if input has type f8E4M3FN type then output must have a type in [16-bit float,32-bit float,bfloat16 type]}}
+  %output = spirv.Tosa.Cast %arg0 : !spirv.arm.tensor<2x3x4xf8E4M3FN> -> !spirv.arm.tensor<2x3x4xi1>
+  spirv.ARM.GraphOutputs %output : !spirv.arm.tensor<2x3x4xi1>
+}
+
+spirv.ARM.Graph @cast_f8e4m3fn_to_f8e4m3fn_not_supported(%arg0: !spirv.arm.tensor<2x3x4xf8E4M3FN>) -> (!spirv.arm.tensor<2x3x4xf8E4M3FN>) {
+  // expected-error @+1 {{op failed to verify that if input has type f8E4M3FN type then output must have a type in [16-bit float,32-bit float,bfloat16 type]}}
+  %output = spirv.Tosa.Cast %arg0 : !spirv.arm.tensor<2x3x4xf8E4M3FN> -> !spirv.arm.tensor<2x3x4xf8E4M3FN>
+  spirv.ARM.GraphOutputs %output : !spirv.arm.tensor<2x3x4xf8E4M3FN>
+}
+
+spirv.ARM.Graph @cast_f8e5m2_to_i16_not_supported(%arg0: !spirv.arm.tensor<2x3x4xf8E5M2>) -> (!spirv.arm.tensor<2x3x4xi16>) {
+  // expected-error @+1 {{op failed to verify that if input has type f8E5M2 type then output must have a type in [16-bit float,32-bit float,bfloat16 type]}}
+  %output = spirv.Tosa.Cast %arg0 : !spirv.arm.tensor<2x3x4xf8E5M2> -> !spirv.arm.tensor<2x3x4xi16>
+  spirv.ARM.GraphOutputs %output : !spirv.arm.tensor<2x3x4xi16>
+}
+
+spirv.ARM.Graph @cast_f8e5m2_to_bool_not_supported(%arg0: !spirv.arm.tensor<2x3x4xf8E5M2>) -> (!spirv.arm.tensor<2x3x4xi1>) {
+  // expected-error @+1 {{op failed to verify that if input has type f8E5M2 type then output must have a type in [16-bit float,32-bit float,bfloat16 type]}}
+  %output = spirv.Tosa.Cast %arg0 : !spirv.arm.tensor<2x3x4xf8E5M2> -> !spirv.arm.tensor<2x3x4xi1>
+  spirv.ARM.GraphOutputs %output : !spirv.arm.tensor<2x3x4xi1>
+}
+
+spirv.ARM.Graph @cast_f8e5m2_to_f8e5m2_not_supported(%arg0: !spirv.arm.tensor<2x3x4xf8E5M2>) -> (!spirv.arm.tensor<2x3x4xf8E5M2>) {
+  // expected-error @+1 {{op failed to verify that if input has type f8E5M2 type then output must have a type in [16-bit float,32-bit float,bfloat16 type]}}
+  %output = spirv.Tosa.Cast %arg0 : !spirv.arm.tensor<2x3x4xf8E5M2> -> !spirv.arm.tensor<2x3x4xf8E5M2>
+  spirv.ARM.GraphOutputs %output : !spirv.arm.tensor<2x3x4xf8E5M2>
+}
+
 //===----------------------------------------------------------------------===//
 // spirv.TOSA.Rescale
 //===----------------------------------------------------------------------===//



More information about the Mlir-commits mailing list