[Mlir-commits] [mlir] [mlir][tosa] Add FP8 support (PR #127730)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Feb 18 19:50:26 PST 2025


https://github.com/Jerry-Ge updated https://github.com/llvm/llvm-project/pull/127730

>From 8f928e66b0cd02b7822f3a7f80073896f16649ac Mon Sep 17 00:00:00 2001
From: Tai Ly <tai.ly at arm.com>
Date: Thu, 1 Feb 2024 23:44:37 +0000
Subject: [PATCH] [mlir][tosa] Add FP8 support

Add FP8 support to the following TOSA operators:

ARGMAX
AVGPOOL
CONV2D
CONV3D
DEPTHWISE_CONV2D
MATMUL
MAX_POOL2D
TRANSPOSE_CONV2D
CONST
CAST
CONCAT
PAD
DIM
RESHAPE
REVERSE
SLICE
TILE
TRANSPOSE
GATHER
SCATTER

Also add verifiers as needed to check input/output element types,
and rename the inputs of transpose_conv2d and select to match the spec.

Signed-off-by: Tai Ly <tai.ly at arm.com>
Signed-off-by: Jerry Ge <jerry.ge at arm.com>
Change-Id: I56adfabb2396e38b7ed3479e4fd680b740bdb4e4
---
 mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td  | 181 ++++++----
 .../mlir/Dialect/Tosa/IR/TosaTypesBase.td     |  31 +-
 .../Dialect/Tosa/IR/TosaCanonicalizations.cpp |  14 +-
 mlir/lib/Dialect/Tosa/IR/TosaOps.cpp          | 339 +++++++++++++++++-
 .../Transforms/TosaDecomposeDepthwise.cpp     |   2 +-
 .../Tosa/Transforms/TosaMakeBroadcastable.cpp |   6 +-
 .../TosaToLinalg/tosa-to-linalg-pipeline.mlir |   2 +-
 .../TosaToTensor/tosa-to-tensor.mlir          |   4 +-
 mlir/test/Dialect/Tosa/canonicalize.mlir      |  21 +-
 mlir/test/Dialect/Tosa/constant-op-fold.mlir  |  11 +
 mlir/test/Dialect/Tosa/invalid.mlir           |  34 +-
 mlir/test/Dialect/Tosa/ops.mlir               | 286 ++++++++++++++-
 .../Tosa/tosa-decompose-depthwise.mlir        |   4 +-
 13 files changed, 806 insertions(+), 129 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index d11ba65a13736..8947f7a9bd9a1 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -41,7 +41,7 @@ def Tosa_ArgMaxOp : Tosa_InferShapedTypeOp<"argmax"> {
   }];
 
   let arguments = (ins
-    Tosa_Tensor: $input,
+    Tosa_Tensor_Extended: $input,
     I32Attr: $axis,
     DefaultValuedAttr<Tosa_NanPropagationAttr, "\"PROPAGATE\"">:$nan_mode
   );
@@ -73,7 +73,8 @@ def Tosa_AvgPool2dOp : Tosa_InferShapedTypeOp<"avg_pool2d"> {
   }];
 
   let arguments = (ins
-    Tosa_Tensor4D:$input,
+    Tosa_Tensor4D_Extended:$input,
+
     Tosa_IntArrayAttr2:$kernel,
     Tosa_IntArrayAttr2:$stride,
     Tosa_IntArrayAttr4:$pad,
@@ -83,7 +84,7 @@ def Tosa_AvgPool2dOp : Tosa_InferShapedTypeOp<"avg_pool2d"> {
   );
 
   let results = (outs
-    Tosa_Tensor4D:$output
+    Tosa_Tensor4D_Extended:$output
   );
 
   let builders = [Tosa_AvgPool2dOpQuantInfoBuilder];
@@ -102,7 +103,7 @@ def Tosa_Conv2DOp : Tosa_ConvOp<"conv2d"> {
   }];
 
   let arguments = (ins
-    Tosa_Tensor4D:$input,
+    Tosa_Tensor4D_Extended:$input,
     TosaTensorRankOf<[Tosa_Weight], [4]>:$weight,
     Tosa_Tensor1D:$bias,
     Optional<Tosa_ScalarTensor>:$input_zp,
@@ -133,11 +134,12 @@ def Tosa_Conv3DOp : Tosa_ConvOp<"conv3d"> {
   }];
 
   let arguments = (ins
-    Tosa_Tensor5D:$input,
-    TosaTensorRankOf<[Tosa_Weight], [5]>:$weight,
-    Tosa_Tensor1D:$bias,
+    Tosa_Tensor5D_Extended:$input,
+    TensorRankOf<[Tosa_Weight], [5]>:$weight,
+    Tosa_Tensor1D_Extended:$bias,
     Optional<Tosa_ScalarTensor>:$input_zp,
     Optional<Tosa_ScalarTensor>:$weight_zp,
+
     Tosa_IntArrayAttr6:$pad,
     Tosa_IntArrayAttr3:$stride,
     Tosa_IntArrayAttr3:$dilation,
@@ -146,7 +148,7 @@ def Tosa_Conv3DOp : Tosa_ConvOp<"conv3d"> {
   );
 
   let results = (outs
-    Tosa_Tensor5D:$output
+    Tosa_Tensor5D_Extended:$output
   );
 
   let builders = [Tosa_ConvOpQuantInfoBuilder];
@@ -165,11 +167,12 @@ def Tosa_DepthwiseConv2DOp : Tosa_ConvOp<"depthwise_conv2d"> {
   }];
 
   let arguments = (ins
-    Tosa_Tensor4D:$input,
+    Tosa_Tensor4D_Extended:$input,
     TosaTensorRankOf<[Tosa_Weight], [4]>:$weight,
-    Tosa_Tensor1D:$bias,
+    Tosa_Tensor1D_Extended:$bias,
     Optional<Tosa_ScalarTensor>:$input_zp,
     Optional<Tosa_ScalarTensor>:$weight_zp,
+
     Tosa_IntArrayAttr4:$pad,
     Tosa_IntArrayAttr2:$stride,
     Tosa_IntArrayAttr2:$dilation,
@@ -178,7 +181,7 @@ def Tosa_DepthwiseConv2DOp : Tosa_ConvOp<"depthwise_conv2d"> {
   );
 
   let results = (outs
-    Tosa_Tensor4D:$output
+    Tosa_Tensor4D_Extended:$output
   );
 
   let builders = [Tosa_ConvOpQuantInfoBuilder];
@@ -237,8 +240,8 @@ def Tosa_MatMulOp : Tosa_InferShapedTypeOp<"matmul"> {
   }];
 
   let arguments = (ins
-    Tosa_Tensor3D:$a,
-    Tosa_Tensor3D:$b,
+    Tosa_Tensor3D_Extended:$a,
+    Tosa_Tensor3D_Extended:$b,
     OptionalAttr<I32Attr>:$a_zp,
     OptionalAttr<I32Attr>:$b_zp
   );
@@ -248,6 +251,7 @@ def Tosa_MatMulOp : Tosa_InferShapedTypeOp<"matmul"> {
   );
 
   let builders = [Tosa_MatMulOpQuantInfoBuilder];
+  let hasVerifier = 1;
 }
 
 //===----------------------------------------------------------------------===//
@@ -264,7 +268,7 @@ def Tosa_MaxPool2dOp : Tosa_InferShapedTypeOp<"max_pool2d"> {
   }];
 
   let arguments = (ins
-    Tosa_Tensor4D:$input,
+    Tosa_Tensor4D_Extended:$input,
 
     Tosa_IntArrayAttr2:$kernel,
     Tosa_IntArrayAttr2:$stride,
@@ -273,10 +277,11 @@ def Tosa_MaxPool2dOp : Tosa_InferShapedTypeOp<"max_pool2d"> {
   );
 
   let results = (outs
-    Tosa_Tensor4D:$output
+    Tosa_Tensor4D_Extended:$output
   );
 
   let hasCanonicalizer = 1;
+  let hasVerifier = 1;
 }
 
 //===----------------------------------------------------------------------===//
@@ -327,11 +332,12 @@ def Tosa_TransposeConv2DOp : Tosa_ConvOp<"transpose_conv2d"> {
   }];
 
   let arguments = (ins
-    Tosa_Tensor4D:$input,
+    Tosa_Tensor4D_Extended:$input,
     TosaTensorRankOf<[Tosa_Weight], [4]>:$weight,
-    Tosa_Tensor1D:$bias,
+    Tosa_Tensor1D_Extended:$bias,
     Optional<Tosa_ScalarTensor>:$input_zp,
     Optional<Tosa_ScalarTensor>:$weight_zp,
+
     Tosa_IntArrayAttr4:$out_pad,
     Tosa_IntArrayAttr2:$stride,
     Tosa_IntArrayAttr4:$out_shape,
@@ -1190,9 +1196,9 @@ def Tosa_SelectOp : Tosa_ElementwiseOp<"select"> {
   }];
 
   let arguments = (ins
-    Tosa_I1Tensor:$pred,
-    Tosa_Tensor:$on_true,
-    Tosa_Tensor:$on_false
+    Tosa_I1Tensor:$input1,
+    Tosa_Tensor:$input2,
+    Tosa_Tensor:$input3
   );
 
   let results = (outs
@@ -1200,9 +1206,10 @@ def Tosa_SelectOp : Tosa_ElementwiseOp<"select"> {
   );
   let hasCanonicalizeMethod = 1;
   let hasFolder = 1;
+  let hasVerifier = 1;
 
   let assemblyFormat = [{
-    operands attr-dict `:` `(` type($pred) `,` type($on_true) `,` type($on_false)
+    operands attr-dict `:` `(` type($input1) `,` type($input2) `,` type($input3)
     `)` `->` type($output)
   }];
 }
@@ -1518,16 +1525,17 @@ def Tosa_ConcatOp : Tosa_InferTensorTypeOp<"concat"> {
   }];
 
   let arguments = (ins
-    Variadic<Tosa_Tensor>:$input1,
+    Variadic<Tosa_Tensor_Extended>:$input1,
     I32Attr:$axis
   );
 
   let results = (outs
-    Tosa_Tensor:$output
+    Tosa_Tensor_Extended:$output
   );
 
   let hasCanonicalizer = 1;
   let hasFolder = 1;
+  let hasVerifier = 1;
 
   let extraClassDeclaration = [{
     /// Returns true when two result types are compatible for this op;
@@ -1563,14 +1571,14 @@ def Tosa_PadOp : Tosa_InferShapedTypeOp<"pad"> {
   }];
 
   let arguments = (ins
-    Tosa_RankedTensor:$input1,
+    Tosa_RankedTensor_Extended:$input1,
     Tosa_Shape:$padding,
-    Optional<Tosa_Rank0Tensor>:$pad_const,
+    Optional<Tosa_ScalarTensor_Extended>:$pad_const,
     OptionalAttr<I32Attr>:$input_zp
   );
 
   let results = (outs
-    Tosa_RankedTensor:$output
+    Tosa_RankedTensor_Extended:$output
   );
 
   let builders = [Tosa_PadOpQuantInfoBuilder,
@@ -1597,12 +1605,12 @@ def Tosa_ReshapeOp : Tosa_InferTensorTypeOp<"reshape"> {
   let hasVerifier = 1;
 
   let arguments = (ins
-    Tosa_Tensor:$input1,
+    Tosa_Tensor_Extended:$input1,
     Tosa_Shape:$shape
   );
 
   let results = (outs
-    Tosa_RankedTensor:$output
+    Tosa_RankedTensor_Extended:$output
   );
 
   let extraClassDeclaration = [{
@@ -1629,12 +1637,12 @@ def Tosa_ReverseOp: Tosa_Op<"reverse", [
   }];
 
   let arguments = (ins
-    Tosa_Tensor:$input1,
+    Tosa_Tensor_Extended:$input1,
     I32Attr:$axis
   );
 
   let results = (outs
-    Tosa_Tensor:$output
+    Tosa_Tensor_Extended:$output
   );
 
   let hasFolder = 1;
@@ -1656,13 +1664,13 @@ def Tosa_SliceOp : Tosa_InferShapedTypeOp<"slice"> {
   }];
 
   let arguments = (ins
-    Tosa_Tensor:$input1,
+    Tosa_Tensor_Extended:$input1,
     Tosa_Shape:$start,
     Tosa_Shape:$size
   );
 
   let results = (outs
-    Tosa_Tensor:$output
+    Tosa_Tensor_Extended:$output
   );
 
   let hasCanonicalizer = 1;
@@ -1681,11 +1689,11 @@ def Tosa_TileOp : Tosa_InferShapedTypeOp<"tile"> {
   }];
 
   let arguments = (ins
-    Tosa_Tensor:$input1,
+    Tosa_Tensor_Extended:$input1,
     Tosa_Shape:$multiples);
 
   let results = (outs
-    Tosa_Tensor:$output
+    Tosa_Tensor_Extended:$output
   );
 
   let extraClassDeclaration = [{
@@ -1709,12 +1717,12 @@ def Tosa_TransposeOp : Tosa_InferShapedTypeOp<"transpose",
   }];
 
   let arguments = (ins
-    Tosa_Tensor:$input1,
-    Tosa_Int32Tensor:$perms
+    Tosa_Tensor_Extended:$input1,
+    Tosa_Int32Or64Tensor:$perms
   );
 
   let results = (
-    outs Tosa_Tensor:$output
+    outs Tosa_Tensor_Extended:$output
   );
 
   let extraClassDeclaration = [{
@@ -1743,13 +1751,14 @@ def Tosa_GatherOp : Tosa_InferShapedTypeOp<"gather"> {
   }];
 
   let arguments = (ins
-    Tosa_Tensor3D:$values,
-    TosaTensorRankOf<[Tosa_Int32], [2]>:$indices
+    Tosa_Tensor3D_Extended:$values,
+    2DTensorOf<[Tosa_Int32]>:$indices
   );
 
   let results = (outs
-    Tosa_Tensor3D:$output
+    Tosa_Tensor3D_Extended:$output
   );
+  let hasVerifier = 1;
 }
 
 //===----------------------------------------------------------------------===//
@@ -1764,14 +1773,15 @@ def Tosa_ScatterOp : Tosa_InferShapedTypeOp<"scatter"> {
   }];
 
   let arguments = (ins
-    Tosa_Tensor3D:$values_in,
-    TosaTensorRankOf<[Tosa_Int32], [2]>:$indices,
-    Tosa_Tensor3D:$input
+    Tosa_Tensor3D_Extended:$values_in,
+    2DTensorOf<[Tosa_Int32]>:$indices,
+    Tosa_Tensor3D_Extended:$input
   );
 
   let results = (outs
-    Tosa_Tensor3D:$values_out
+    Tosa_Tensor3D_Extended:$values_out
   );
+  let hasVerifier = 1;
 }
 
 //===----------------------------------------------------------------------===//
@@ -1828,37 +1838,66 @@ def Tosa_CastOp: Tosa_Op<"cast", [Pure,
 
     | Mode                     | Input   | Output  |
     |--------------------------|---------|---------|
-    | signed 8 to bool         | int8    | Boolean |
-    | signed 16 to bool        | int16   | Boolean |
-    | signed 32 to bool        | int32   | Boolean |
-    | bool to 8                | Boolean | int8    |
-    | bool to 16               | Boolean | int16   |
-    | bool to 32               | Boolean | int32   |
-    | signed 8 to signed 16    | int8    | int16   |
-    | signed 8 to signed 32    | int8    | int32   |
-    | signed 16 to signed 8    | int16   | int8    |
-    | signed 16 to signed 32   | int16   | int32   |
-    | signed 32 to signed 8    | int32   | int8    |
-    | signed 32 to signed 16   | int32   | int16   |
-    | float to signed 8        | float   | int8    |
-    | float to signed 16       | float   | int16   |
-    | signed 8 to float        | int8    | float   |
-    | signed 16 to float       | int16   | float   |
-    | float 32 to float 64     | float32 | float64 |
-    | float 64 to float 32     | float64 | float32 |
-  }];
-
-  let arguments = (ins
-    Tosa_Tensor:$input
-  );
-
-  let results = (outs
-    Tosa_Tensor:$output
+    | bool to int 8            | Boolean | int8    |
+    | bool to int 16           | Boolean | int16   |
+    | bool to int 32           | Boolean | int32   |
+    | int 8 to bool            | int8    | Boolean |
+    | int 8 to int 16          | int8    | int16   |
+    | int 8 to int 32          | int8    | int32   |
+    | int 8 to fp16            | int8    | float16 |
+    | int 8 to bf16            | int8    | bf16    |
+    | int 8 to fp32            | int8    | float32 |
+    | int 16 to bool           | int16   | Boolean |
+    | int 16 to int 8          | int16   | int8    |
+    | int 16 to int 32         | int16   | int32   |
+    | int 16 to fp16           | int16   | float16 |
+    | int 16 to bf16           | int16   | bf16    |
+    | int 16 to fp32           | int16   | float32 |
+    | int 32 to bool           | int32   | Boolean |
+    | int 32 to int 8          | int32   | int8    |
+    | int 32 to int 16         | int32   | int16   |
+    | int 32 to fp16           | int32   | float16 |
+    | int 32 to bf16           | int32   | bf16    |
+    | int 32 to fp32           | int32   | float32 |
+    | bf16 to int 8            | bf16    | int8    |
+    | bf16 to int 16           | bf16    | int16   |
+    | bf16 to int 32           | bf16    | int32   |
+    | bf16 to fp8e4m3          | bf16    | fp8e4m3 |
+    | bf16 to fp8e5m2          | bf16    | fp8e5m2 |
+    | bf16 to fp32             | bf16    | float32 |
+    | fp8e4m3 to fp16          | fp8e4m3 | float16 |
+    | fp8e4m3 to bf16          | fp8e4m3 | bf16    |
+    | fp8e4m3 to fp32          | fp8e4m3 | float32 |
+    | fp8e5m2 to fp16          | fp8e5m2 | float16 |
+    | fp8e5m2 to bf16          | fp8e5m2 | bf16    |
+    | fp8e5m2 to fp32          | fp8e5m2 | float32 |
+    | fp16 to int 8            | float16 | int8    |
+    | fp16 to int 16           | float16 | int16   |
+    | fp16 to int 32           | float16 | int32   |
+    | fp16 to fp8e4m3          | float16 | fp8e4m3 |
+    | fp16 to fp8e5m2          | float16 | fp8e5m2 |
+    | fp16 to fp32             | float16 | float32 |
+    | fp32 to int 8            | float32 | int8    |
+    | fp32 to int 16           | float32 | int16   |
+    | fp32 to int 32           | float32 | int32   |
+    | fp32 to fp8e4m3          | float32 | fp8e4m3 |
+    | fp32 to fp8e5m2          | float32 | fp8e5m2 |
+    | fp32 to bf16             | float32 | bf16    |
+    | fp32 to fp16             | float32 | float16 |
+  }];
+
+  let arguments = (ins
+    TensorOf<[AnyTypeOf<[Tosa_AnyNumber_Extended, F64]>]>:$input
+  );
+
+  let results = (outs
+    TensorOf<[AnyTypeOf<[Tosa_AnyNumber_Extended, F64]>]>:$output
   );
 
   let assemblyFormat = "operands attr-dict `:` functional-type(operands, results)";
 
   let hasFolder = 1;
+  let hasVerifier = 1;
 }
 
 //===----------------------------------------------------------------------===//
@@ -1940,7 +1979,7 @@ def Tosa_ConstOp : Tosa_Op<"const", [ConstantLike, Pure,
   );
 
   let results = (outs
-    TosaTensorOf<[AnyTypeOf<[Tosa_AnyNumber]>]>:$output
+    TensorOf<[AnyTypeOf<[Tosa_AnyNumber_Extended, F64, Tosa_Int4]>]>:$output
   );
 
   let hasFolder = 1;
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
index cf6ddc66f4ada..2c6e647ae32fd 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
@@ -74,16 +74,25 @@ def Tosa_QuantizedInt : AnyTypeOf<[Tosa_QuantizedType<"uint8", [8], 0>,
                                    Tosa_QuantizedType<"int16", [16, 0], 1>,
                                    Tosa_QuantizedType<"int32", [32, 0], 1>]>;
 
+def Tosa_F8 : AnyTypeOf<[
+                        F8E4M3FN,
+                        F8E5M2]>;
+
 //===----------------------------------------------------------------------===//
 // Multi-category types.
 //===----------------------------------------------------------------------===//
 def Tosa_AnyNumber : AnyTypeOf<[Tosa_Int, Tosa_QuantizedInt, AnyFloat],
                                 "number">;
 
+// Tosa_AnyNumber extended with the FP8 element types in Tosa_F8.
+def Tosa_AnyNumber_Extended : AnyTypeOf<[Tosa_AnyNumber, Tosa_F8],
+                               "number_extended">;
+
 // For weight tensors from tosa::Conv2DOp, tosa::Conv3DOp,
 // tosa::DepthwiseConv2DOp, tosa::TransposeConv2DOp
 def Tosa_Weight : AnyTypeOf<[Tosa_Int4, Tosa_Int8,
-                             Tosa_QuantizedInt, AnyFloat]>;
+                             Tosa_QuantizedInt, AnyFloat, Tosa_F8]>;
+
 
 //===----------------------------------------------------------------------===//
 // TOSA Tensor Conformance
@@ -130,9 +139,11 @@ def Tosa_FloatTensor : TosaTensorOf<[AnyFloat]>;
 
 // Either ranked or unranked tensor of TOSA supported element types.
 def Tosa_Tensor : TosaTensorOf<[Tosa_AnyNumber]>;
+def Tosa_Tensor_Extended : TosaTensorOf<[Tosa_AnyNumber_Extended]>;
 
 // Must be ranked but no further constraints
-def Tosa_RankedTensor : TosaRankedTensorOf<[Tosa_AnyNumber]>;
+def Tosa_RankedTensor : RankedTensorOf<[Tosa_AnyNumber]>;
+def Tosa_RankedTensor_Extended : RankedTensorOf<[Tosa_AnyNumber_Extended]>;
 
 // Any tensor element type allowed in Tosa ops.
 def Tosa_ElementType : Type<Or<[Tosa_Int.predicate, Tosa_QuantizedInt.predicate,
@@ -145,9 +156,9 @@ class Tosa_TensorOfOrNone<list<Type> allowedTypes, string description = ""> :
 // Tensor types with constrained ranks.
 //===----------------------------------------------------------------------===//
 
-def Tosa_Rank0Tensor : TosaTensorRankOf<[Tosa_AnyNumber], [0]>;
-
+// Scalar tensors: Rank-1 (with only one element)
 def Tosa_ScalarTensor : TosaScalarTensorOf<[Tosa_AnyNumber], [1]>;
+def Tosa_ScalarTensor_Extended : TosaScalarTensorOf<[Tosa_AnyNumber_Extended], [1]>;
 def Tosa_ScalarInt8Tensor : TosaScalarTensorOf<[Tosa_Int8], [1]>;
 
 // We include unranked tensors as a supported type for all possible tosa
@@ -155,6 +166,7 @@ def Tosa_ScalarInt8Tensor : TosaScalarTensorOf<[Tosa_Int8], [1]>;
 // they should be shape propagate used Tosa's shape inference pass and verified
 // to not include any remaining unranked tensors.
 def Tosa_UnrankedTensor : TosaUnrankedTensorOf<[Tosa_AnyNumber]>;
+def Tosa_UnrankedTensorExtended : TosaUnrankedTensorOf<[Tosa_AnyNumber_Extended]>;
 
 def Tosa_Tensor1D : AnyTypeOf<[Tosa_UnrankedTensor, TosaTensorRankOf<[Tosa_AnyNumber], [1]>], "1-d tosa-conformant tensor", "::mlir::TensorType">;
 def Tosa_Tensor2D : AnyTypeOf<[Tosa_UnrankedTensor, TosaTensorRankOf<[Tosa_AnyNumber], [2]>], "2-d tosa-conformant tensor", "::mlir::TensorType">;
@@ -162,6 +174,17 @@ def Tosa_Tensor3D : AnyTypeOf<[Tosa_UnrankedTensor, TosaTensorRankOf<[Tosa_AnyNu
 def Tosa_Tensor4D : AnyTypeOf<[Tosa_UnrankedTensor, TosaTensorRankOf<[Tosa_AnyNumber], [4]>], "4-d tosa-conformant tensor", "::mlir::TensorType">;
 def Tosa_Tensor5D : AnyTypeOf<[Tosa_UnrankedTensor, TosaTensorRankOf<[Tosa_AnyNumber], [5]>], "5-d tosa-conformant tensor", "::mlir::TensorType">;
 
+def Tosa_Tensor1D_Extended : AnyTypeOf<[Tosa_UnrankedTensorExtended, TosaTensorRankOf<[Tosa_AnyNumber_Extended], [1]>],
+    "1-d tosa-conformant tensor extended", "::mlir::TensorType">;
+def Tosa_Tensor2D_Extended : AnyTypeOf<[Tosa_UnrankedTensorExtended, TosaTensorRankOf<[Tosa_AnyNumber_Extended], [2]>],
+    "2-d tosa-conformant tensor extended", "::mlir::TensorType">;
+def Tosa_Tensor3D_Extended : AnyTypeOf<[Tosa_UnrankedTensorExtended, TosaTensorRankOf<[Tosa_AnyNumber_Extended], [3]>],
+    "3-d tosa-conformant tensor extended", "::mlir::TensorType">;
+def Tosa_Tensor4D_Extended : AnyTypeOf<[Tosa_UnrankedTensorExtended, TosaTensorRankOf<[Tosa_AnyNumber_Extended], [4]>],
+    "4-d tosa-conformant tensor extended", "::mlir::TensorType">;
+def Tosa_Tensor5D_Extended : AnyTypeOf<[Tosa_UnrankedTensorExtended, TosaTensorRankOf<[Tosa_AnyNumber_Extended], [5]>],
+    "5-d tosa-conformant tensor extended", "::mlir::TensorType">;
+
 // Ranked tensors up to given rank.
 def Tosa_Tensor1Dto4D : AnyTypeOf<[
   Tosa_UnrankedTensor, TosaTensorRankOf<[Tosa_AnyNumber], [1,2,3,4]>]>;
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index 69b3f6d674167..704f8a82d11fa 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -65,12 +65,12 @@ void ConcatOp::getCanonicalizationPatterns(RewritePatternSet &results,
 }
 
 LogicalResult SelectOp::canonicalize(SelectOp op, PatternRewriter &rewriter) {
-  auto notOp = op.getPred().getDefiningOp<tosa::LogicalNotOp>();
+  auto notOp = op.getInput1().getDefiningOp<tosa::LogicalNotOp>();
   if (!notOp)
     return failure();
   rewriter.modifyOpInPlace(op, [&]() {
     op.getOperation()->setOperands(
-        {notOp.getInput1(), op.getOnFalse(), op.getOnTrue()});
+        {notOp.getInput1(), op.getInput3(), op.getInput2()});
   });
   return success();
 }
@@ -1118,18 +1118,18 @@ OpFoldResult SliceOp::fold(FoldAdaptor adaptor) {
 }
 
 OpFoldResult tosa::SelectOp::fold(FoldAdaptor adaptor) {
-  if (getOnTrue() == getOnFalse())
-    return getOnTrue();
+  if (getInput2() == getInput3())
+    return getInput2();
 
   auto predicate =
-      llvm::dyn_cast_if_present<DenseIntElementsAttr>(adaptor.getPred());
+      llvm::dyn_cast_if_present<DenseIntElementsAttr>(adaptor.getInput1());
   if (!predicate)
     return {};
 
   if (!predicate.isSplat())
     return {};
-  return predicate.getSplatValue<APInt>().getBoolValue() ? getOnTrue()
-                                                         : getOnFalse();
+  return predicate.getSplatValue<APInt>().getBoolValue() ? getInput2()
+                                                         : getInput3();
 }
 
 OpFoldResult TileOp::fold(FoldAdaptor adaptor) {
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 67021d6c07401..d88050f6118d7 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -217,15 +217,17 @@ void mlir::tosa::printTypeOrAttr(OpAsmPrinter &p, Operation *op, TypeAttr type,
 
 template <typename T>
 static LogicalResult verifyConvOp(T op) {
-  // All TOSA conv ops have an input and weight arguments which must be ranked
-  // tensors.
+  // All TOSA conv ops have input and weight operands.
   auto inputType = llvm::dyn_cast<RankedTensorType>(op.getInput().getType());
+
+  RankedTensorType weightType;
+  weightType = llvm::dyn_cast<RankedTensorType>(op.getWeight().getType());
+
+  // Must be ranked tensor types
   if (!inputType) {
     op.emitOpError("expect a ranked tensor for input, got ") << op.getInput();
     return failure();
   }
-
-  auto weightType = llvm::dyn_cast<RankedTensorType>(op.getWeight().getType());
   if (!weightType) {
     op.emitOpError("expect a ranked tensor for weight, got ") << op.getWeight();
     return failure();
@@ -243,6 +245,9 @@ static LogicalResult verifyConvOp(T op) {
   if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(inputEType))
     inputEType = quantType.getStorageType();
 
+  if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(weightEType))
+    weightEType = quantType.getStorageType();
+
   if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(biasEType))
     biasEType = quantType.getStorageType();
 
@@ -258,14 +263,22 @@ static LogicalResult verifyConvOp(T op) {
     return failure();
   }
 
-  if (isa<Float8E5M2Type>(inputEType) || isa<Float8E4M3FNType>(inputEType) ||
-      isa<Float8E5M2Type>(weightEType) || isa<Float8E4M3FNType>(weightEType)) {
+  if (llvm::isa<Float8E5M2Type>(inputEType) ||
+      llvm::isa<Float8E4M3FNType>(inputEType) ||
+      llvm::isa<Float8E5M2Type>(weightEType) ||
+      llvm::isa<Float8E4M3FNType>(weightEType)) {
     if (inputEType != weightEType) {
       op.emitOpError(
           "expect both input and weight to have same element type, got ")
           << inputEType << " and " << weightEType;
       return failure();
     }
+
+    if (!resultEType.isF16()) {
+      op.emitOpError("expect bias and result element type to be f16, got ")
+          << resultEType;
+      return failure();
+    }
   }
 
   bool inputIsFloat = llvm::isa<FloatType>(inputEType);
@@ -459,9 +472,18 @@ LogicalResult tosa::AvgPool2dOp::verify() {
   if (inputETy.isF32() && !accType.isF32())
     return emitOpError("accumulator type for f32 tensor is not f32");
 
+  if ((llvm::isa<Float8E5M2Type>(inputETy) ||
+       llvm::isa<Float8E4M3FNType>(inputETy)) &&
+      !accType.isF16())
+    return emitOpError("accumulator type for f8 tensor is not f16");
+
   if ((inputETy.isF32() && resultETy.isF32()) ||
       (inputETy.isF16() && resultETy.isF16()) ||
       (inputETy.isBF16() && resultETy.isBF16()) ||
+      (llvm::isa<Float8E5M2Type>(inputETy) &&
+       llvm::isa<Float8E5M2Type>(resultETy)) ||
+      (llvm::isa<Float8E4M3FNType>(inputETy) &&
+       llvm::isa<Float8E4M3FNType>(resultETy)) ||
       (inputETy.isInteger(8) && resultETy.isInteger(8)) ||
       (inputETy.isInteger(16) && resultETy.isInteger(16)))
     return success();
@@ -469,6 +491,104 @@ LogicalResult tosa::AvgPool2dOp::verify() {
   return emitOpError("input/output element types are incompatible.");
 }
 
+LogicalResult tosa::CastOp::verify() {
+  mlir::Type inputETy =
+      llvm::cast<ShapedType>(getInput().getType()).getElementType();
+  if (auto inputQuantType =
+          llvm::dyn_cast<mlir::quant::QuantizedType>(inputETy)) {
+    inputETy = inputQuantType.getStorageType();
+  }
+  mlir::Type outputETy =
+      llvm::cast<ShapedType>(getOutput().getType()).getElementType();
+  if (auto outputQuantType =
+          llvm::dyn_cast<mlir::quant::QuantizedType>(outputETy)) {
+    outputETy = outputQuantType.getStorageType();
+  }
+
+  // input element type: bool
+  if (inputETy.isInteger(1)) {
+    if (outputETy.isInteger(8) || outputETy.isInteger(16) ||
+        outputETy.isInteger(32)) {
+      return success();
+    }
+  }
+  // input element type: int8
+  if (inputETy.isInteger(8)) {
+    if (outputETy.isInteger(1) || outputETy.isInteger(16) ||
+        outputETy.isInteger(32) || outputETy.isF16() || outputETy.isBF16() ||
+        outputETy.isF32()) {
+      return success();
+    }
+  }
+  // input element type: int16
+  if (inputETy.isInteger(16)) {
+    if (outputETy.isInteger(1) || outputETy.isInteger(8) ||
+        outputETy.isInteger(32) || outputETy.isF16() || outputETy.isBF16() ||
+        outputETy.isF32()) {
+      return success();
+    }
+  }
+  // input element type: int32
+  if (inputETy.isInteger(32)) {
+    if (outputETy.isInteger(1) || outputETy.isInteger(8) ||
+        outputETy.isInteger(16) || outputETy.isF16() || outputETy.isBF16() ||
+        outputETy.isF32()) {
+      return success();
+    }
+  }
+  // input element type: bf16 or fp16
+  if (inputETy.isBF16() || inputETy.isF16()) {
+    if (outputETy.isInteger(8) || outputETy.isInteger(16) ||
+        outputETy.isInteger(32) || llvm::isa<Float8E5M2Type>(outputETy) ||
+        llvm::isa<Float8E4M3FNType>(outputETy) || outputETy.isF32()) {
+      return success();
+    }
+  }
+  // input element type: f8e4m3 or f8e5m2
+  if (llvm::isa<Float8E4M3FNType>(inputETy) ||
+      llvm::isa<Float8E5M2Type>(inputETy)) {
+    if (outputETy.isF16() || outputETy.isBF16() || outputETy.isF32()) {
+      return success();
+    }
+  }
+  // input element type: fp32
+  if (inputETy.isF32()) {
+    if (outputETy.isInteger(8) || outputETy.isInteger(16) ||
+        outputETy.isInteger(32) || llvm::isa<Float8E5M2Type>(outputETy) ||
+        llvm::isa<Float8E4M3FNType>(outputETy) || outputETy.isF16() ||
+        outputETy.isBF16()) {
+      return success();
+    }
+  }
+
+  // The following cases are outside of the TOSA spec.
+
+  // allow casting to same type, for quantization/dequantization
+  if (inputETy == outputETy) {
+    return success();
+  }
+
+  // allow casting float to bool, for tosa_to_linalg testing
+  if (inputETy.isF32() && outputETy.isInteger(1)) {
+    return success();
+  }
+
+  // special case for I64
+  if (inputETy.isInteger(64) || outputETy.isInteger(64)) {
+    // be forgiving of casting to and from I64
+    return success();
+  }
+
+  // special case for fp64
+  if (inputETy.isF64() || outputETy.isF64()) {
+    // be forgiving of casting to and from F64
+    return success();
+  }
+
+  return emitOpError("input/output element types are incompatible: ")
+         << inputETy << " and " << outputETy;
+}
+
 LogicalResult tosa::ClampOp::verify() {
   mlir::Type inputETy =
       llvm::cast<ShapedType>(getInput().getType()).getElementType();
@@ -849,6 +969,18 @@ LogicalResult tosa::ConcatOp::inferReturnTypeComponents(
   return success();
 }
 
+LogicalResult tosa::ConcatOp::verify() {
+  // check that each input has the same element type as the output
+  auto outType = getOutput().getType();
+  for (auto input : getInput1()) {
+    if (verifySameElementTypes(*this, /* inType = */ input.getType(), outType)
+            .failed()) {
+      return failure();
+    }
+  }
+  return success();
+}
+
 LogicalResult tosa::EqualOp::inferReturnTypeComponents(
     MLIRContext *context, ::std::optional<Location> location,
     ValueShapeRange operands, DictionaryAttr attributes,
@@ -898,6 +1030,108 @@ LogicalResult tosa::MatMulOp::inferReturnTypeComponents(
   return success();
 }
 
+LogicalResult MatMulOp::verify() {
+  auto aType = llvm::dyn_cast<ShapedType>(getA().getType());
+  auto bType = llvm::dyn_cast<ShapedType>(getB().getType());
+  auto resultEType =
+      llvm::cast<ShapedType>(getResult().getType()).getElementType();
+
+  // Must be shaped tensor types
+  if (!aType) {
+    emitOpError("expect a shaped tensor for input a, got ") << getA().getType();
+    return failure();
+  }
+  if (!bType) {
+    emitOpError("expect a shaped tensor for input b, got ") << getB().getType();
+    return failure();
+  }
+
+  auto aElementType = aType.getElementType();
+  auto bElementType = bType.getElementType();
+
+  auto aQuantizedEType =
+      llvm::dyn_cast<quant::UniformQuantizedType>(aElementType);
+  auto bQuantizedEType =
+      llvm::dyn_cast<quant::UniformQuantizedType>(bElementType);
+
+  if (aQuantizedEType || bQuantizedEType) {
+    if (!aQuantizedEType || !bQuantizedEType) {
+      emitOpError(
+          "expect operands to be both quantized or both not quantized, got ")
+          << aElementType << " and " << bElementType;
+      return failure();
+    }
+    // both a and b have quantized element types
+    auto aQuantWidth = aQuantizedEType.getStorageTypeIntegralWidth();
+    auto bQuantWidth = bQuantizedEType.getStorageTypeIntegralWidth();
+    if (aQuantWidth != bQuantWidth) {
+      emitOpError("expect quantized operands to have same widths, got ")
+          << aQuantWidth << " and " << bQuantWidth;
+      return failure();
+    }
+
+    if (aQuantWidth != 8 && aQuantWidth != 16) {
+      emitOpError("only support quantized types with width of 8 or 16, got ")
+          << aQuantWidth;
+      return failure();
+    }
+
+    // 8-bit quantized operands require an i32 result element type
+    if (aQuantWidth == 8 && !resultEType.isInteger(32)) {
+      emitOpError("expect result element type to be i32, got ") << resultEType;
+      return failure();
+    }
+
+    // 16-bit quantized operands require an i48 result element type
+    if (aQuantWidth == 16 && !resultEType.isInteger(48)) {
+      emitOpError("expect result element type to be i48, got ") << resultEType;
+      return failure();
+    }
+
+    return success();
+  }
+
+  // non-quantized element types
+
+  if (aElementType != bElementType) {
+    emitOpError("expect same element type for inputs a and b, got ")
+        << aElementType << " and " << bElementType;
+    return failure();
+  }
+
+  if (llvm::isa<Float8E5M2Type>(aElementType) ||
+      llvm::isa<Float8E4M3FNType>(aElementType)) {
+    if (!resultEType.isF16()) {
+      emitOpError("expect result element type to be f16, got ") << resultEType;
+      return failure();
+    }
+  }
+
+  if (aElementType.isInteger(8) && !resultEType.isInteger(32)) {
+    emitOpError("expect result element type to be i32, got ") << resultEType;
+    return failure();
+  }
+  if (aElementType.isInteger(16) && !resultEType.isInteger(48)) {
+    emitOpError("expect result element type to be i48, got ") << resultEType;
+    return failure();
+  }
+  if (aElementType.isF16() && !(resultEType.isF16() || resultEType.isF32())) {
+    emitOpError("expect result element type to be f16 or f32, got ")
+        << resultEType;
+    return failure();
+  }
+  if (aElementType.isBF16() && !resultEType.isF32()) {
+    emitOpError("expect result element type to be f32, got ") << resultEType;
+    return failure();
+  }
+  if (aElementType.isF32() && !resultEType.isF32()) {
+    emitOpError("expect result element type to be f32, got ") << resultEType;
+    return failure();
+  }
+
+  return success();
+}
+
 LogicalResult tosa::PadOp::inferReturnTypeComponents(
     MLIRContext *context, ::std::optional<Location> location,
     PadOp::Adaptor adaptor,
@@ -958,6 +1192,18 @@ LogicalResult tosa::PadOp::verify() {
                          << inputType.getRank() * 2
                          << " (2*rank(shape1)) but got size " << paddingRank;
 
+  if (verifySameElementTypes(*this, /* inType = */ getInput1().getType(),
+                             /* outType = */ getOutput().getType())
+          .failed()) {
+    return failure();
+  }
+  if (auto padConst = getPadConst()) {
+    if (verifySameElementTypes(*this, /* inType = */ padConst.getType(),
+                               /* outType = */ getOutput().getType())
+            .failed()) {
+      return failure();
+    }
+  }
   return success();
 }
 
@@ -1023,17 +1269,20 @@ LogicalResult tosa::SliceOp::verify() {
   if (!inputType)
     return success();
 
+  if (verifySameElementTypes(*this, /* inType = */ getInput1().getType(),
+                             /* outType = */ getOutput().getType())
+          .failed())
+    return failure();
+
   auto startShapeRank =
       llvm::cast<tosa::shapeType>(getStart().getType()).getRank();
   if (inputType.getRank() != startShapeRank)
-    return emitOpError(
-        "length of start attribute is not equal rank of input shape");
+    return emitOpError("length of start is not equal to rank of input shape");
 
   auto sizeShapeRank =
       llvm::cast<tosa::shapeType>(getSize().getType()).getRank();
   if (inputType.getRank() != sizeShapeRank)
-    return emitOpError(
-        "length of size attribute is not equal rank of input shape");
+    return emitOpError("length of size is not equal to rank of input shape");
 
   return success();
 }
@@ -1238,6 +1487,11 @@ LogicalResult tosa::TileOp::inferReturnTypeComponents(
 }
 
 LogicalResult tosa::TileOp::verify() {
+  if (verifySameElementTypes(*this, /* intype = */ getInput1().getType(),
+                             /* outType = */ getOutput().getType())
+          .failed()) {
+    return failure();
+  }
   ShapedType inputType = llvm::cast<ShapedType>(getInput1().getType());
   ShapedType outputType = llvm::cast<ShapedType>(getType());
 
@@ -1322,6 +1576,12 @@ llvm::LogicalResult tosa::ReshapeOp::verify() {
   TensorType inputType = getInput1().getType();
   RankedTensorType outputType = getType();
 
+  if (verifySameElementTypes(*this, /* inType = */ inputType,
+                             /* outType = */ outputType)
+          .failed()) {
+    return failure();
+  }
+
   SmallVector<int64_t> shapeValues;
   if (!tosa::getConstShapeValue(getShape().getDefiningOp(), shapeValues)) {
     // skip following checks if shape is not constant
@@ -1463,6 +1723,11 @@ LogicalResult tosa::TransposeOp::inferReturnTypeComponents(
 }
 
 LogicalResult tosa::TransposeOp::verify() {
+  if (verifySameElementTypes(*this, /* inType = */ getInput1().getType(),
+                             /* outType = */ getOutput().getType())
+          .failed()) {
+    return failure();
+  }
   TensorType inputType = getInput1().getType();
   TensorType permType = getPerms().getType();
   TensorType outputType = getOutput().getType();
@@ -1578,6 +1843,11 @@ LogicalResult tosa::GatherOp::inferReturnTypeComponents(
   return success();
 }
 
+LogicalResult tosa::GatherOp::verify() {
+  return verifySameElementTypes(*this, /* inType = */ getValues().getType(),
+                                /* outType = */ getOutput().getType());
+}
+
 LogicalResult tosa::ResizeOp::inferReturnTypeComponents(
     MLIRContext *context, ::std::optional<Location> location,
     ResizeOp::Adaptor adaptor,
@@ -1649,6 +1919,18 @@ LogicalResult tosa::ScatterOp::inferReturnTypeComponents(
   return success();
 }
 
+LogicalResult tosa::ScatterOp::verify() {
+  if (verifySameElementTypes(*this, /* inType = */ getValuesIn().getType(),
+                             /* outType = */ getValuesOut().getType())
+          .failed() ||
+      verifySameElementTypes(*this, /* inType = */ getInput().getType(),
+                             /* outType = */ getValuesOut().getType())
+          .failed()) {
+    return failure();
+  }
+  return success();
+}
+
 static LogicalResult ReduceInferReturnTypes(
     ShapeAdaptor operandShape, Type inputType, IntegerAttr axis,
     SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
@@ -2013,6 +2295,11 @@ LogicalResult MaxPool2dOp::inferReturnTypeComponents(
                                  inferredReturnShapes);
 }
 
+LogicalResult MaxPool2dOp::verify() {
+  return verifySameElementTypes(*this, /* inType = */ getInput().getType(),
+                                /* outType = */ getOutput().getType());
+}
+
 LogicalResult DepthwiseConv2DOp::inferReturnTypeComponents(
     MLIRContext *context, ::std::optional<Location> location,
     DepthwiseConv2DOp::Adaptor adaptor,
@@ -2317,6 +2604,11 @@ void IfOp::print(OpAsmPrinter &p) {
 LogicalResult ReverseOp::verify() {
   TensorType inputType = getInput1().getType();
   TensorType outputType = getOutput().getType();
+  if (verifySameElementTypes(*this, /* inType = */ inputType,
+                             /* outType = */ outputType)
+          .failed())
+    return failure();
+
   int32_t reverseAxis = getAxis();
 
   if (reverseAxis < 0)
@@ -2343,6 +2635,33 @@ LogicalResult ReverseOp::verify() {
   return success();
 }
 
+LogicalResult tosa::SelectOp::verify() {
+  // input2 and input3 must each have the same element type as the output
+  if (verifySameElementTypes(*this, /* inType = */ getInput2().getType(),
+                             /* outType = */ getOutput().getType())
+          .failed() ||
+      verifySameElementTypes(*this, /* inType = */ getInput3().getType(),
+                             /* outType = */ getOutput().getType())
+          .failed()) {
+    return failure();
+  }
+  // input1 (the predicate) must be a shaped type with i1 (bool) elements
+  auto predicateType = llvm::dyn_cast<ShapedType>(getInput1().getType());
+  if (!predicateType) {
+    emitOpError("expect shaped tensor for input1, got ")
+        << getInput1().getType();
+    return failure();
+  }
+  auto predicateElementType = predicateType.getElementType();
+  if (!predicateElementType.isInteger(1)) {
+    emitOpError("expect element type of bool for input1, got ")
+        << predicateElementType;
+    return failure();
+  }
+
+  return success();
+}
+
 // parse and print of WhileOp refer to the implementation of SCF dialect.
 ParseResult WhileOp::parse(OpAsmParser &parser, OperationState &result) {
   SmallVector<OpAsmParser::Argument, 4> regionArgs;
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp
index b26397d0e3ed7..a716c2b7595ca 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp
@@ -117,7 +117,7 @@ struct DepthwiseConv2DIsMul : public OpRewritePattern<tosa::DepthwiseConv2DOp> {
 
       Value padSizeVal = getTosaConstShape(rewriter, op->getLoc(), pad);
 
-      auto padTy = RankedTensorType::get({}, inputETy);
+      auto padTy = RankedTensorType::get({1}, inputETy);
       auto padAttr = DenseElementsAttr::get(padTy, zeroAttr);
       Value padVal =
           rewriter.create<tosa::ConstOp>(op->getLoc(), padTy, padAttr);
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaMakeBroadcastable.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaMakeBroadcastable.cpp
index 79afc75fd6c8e..87b2a2695351b 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaMakeBroadcastable.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaMakeBroadcastable.cpp
@@ -169,9 +169,9 @@ struct ConvertTosaOp<tosa::SelectOp> : public OpRewritePattern<tosa::SelectOp> {
   LogicalResult matchAndRewrite(tosa::SelectOp tosaOp,
                                 PatternRewriter &rewriter) const override {
 
-    Value input1 = tosaOp.getPred();
-    Value input2 = tosaOp.getOnTrue();
-    Value input3 = tosaOp.getOnFalse();
+    Value input1 = tosaOp.getInput1();
+    Value input2 = tosaOp.getInput2();
+    Value input3 = tosaOp.getInput3();
     Value output = tosaOp.getResult();
 
     auto outputType = dyn_cast<RankedTensorType>(output.getType());
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-pipeline.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-pipeline.mlir
index 73da2810abe04..98c7a2c0be845 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-pipeline.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-pipeline.mlir
@@ -24,7 +24,7 @@ func.func @tensor_with_unknown_rank(%arg0: tensor<*xi8>) -> tensor<*xi8> {
 
 // check that tosa verify kick in
 func.func @test_avg_pool2d_zero_dim_input(%arg0: tensor<1x0x?x9xf32>) -> tensor<1x7x7x9xf32> {
-  // expected-error at +1 {{'tosa.avg_pool2d' op operand #0 must be 4-d tosa-conformant tensor, but got 'tensor<1x0x?x9xf32>'}}
+  // expected-error at +1 {{'tosa.avg_pool2d' op operand #0 must be 4-d tosa-conformant tensor extended, but got 'tensor<1x0x?x9xf32>'}}
     %0 = "tosa.avg_pool2d"(%arg0) {acc_type = f32, kernel = array<i64: 2, 2>, pad = array<i64: 0, 1, 0, 1>, stride = array<i64: 1, 1>}
       : (tensor<1x0x?x9xf32>) -> tensor<1x7x7x9xf32>
     return %0 : tensor<1x7x7x9xf32>
diff --git a/mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir b/mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir
index c2eaba4c563d0..6b7f622d3303f 100644
--- a/mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir
+++ b/mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir
@@ -542,8 +542,8 @@ func.func @pad_float_explicit(%arg0 : tensor<1x2xf32>) -> (tensor<4x9xf32>) {
   // CHECK: tensor.pad %[[ARG0]] low{{\[}}[[INDEX1]], [[INDEX3]]] high{{\[}}[[INDEX2]], [[INDEX4]]]  {
   // CHECK:   tensor.yield [[CST]]
   // CHECK: } : tensor<1x2xf32> to tensor<4x9xf32>
-  %1 = arith.constant dense<42.0> : tensor<f32>
-  %2 = "tosa.pad"(%arg0, %0, %1)  : (tensor<1x2xf32>, !tosa.shape<4>, tensor<f32>)  -> (tensor<4x9xf32>)
+  %1 = arith.constant dense<42.0> : tensor<1xf32>
+  %2 = "tosa.pad"(%arg0, %0, %1)  : (tensor<1x2xf32>, !tosa.shape<4>, tensor<1xf32>)  -> (tensor<4x9xf32>)
   return %2 : tensor<4x9xf32>
 }
 
diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir
index 24d572244a9b0..95e1d8b0009d4 100644
--- a/mlir/test/Dialect/Tosa/canonicalize.mlir
+++ b/mlir/test/Dialect/Tosa/canonicalize.mlir
@@ -265,7 +265,8 @@ func.func @pad_noop_padding_mismatch_nofold(%arg0: tensor<?x?xf32>) -> tensor<?x
   // CHECK: %[[PAD:.+]] = tosa.pad
   // CHECK: return %[[PAD]]
   %shape = tosa.const_shape { value = dense<[1, 0, 0, 1]> : tensor<4xindex>} : () -> !tosa.shape<4>
-  %1 = tosa.pad %arg0, %shape : (tensor<?x?xf32>, !tosa.shape<4>) -> tensor<?x?xf32>
+  %pad_const = "tosa.const"() {value = dense<3.14> : tensor<1xf32>} : () -> tensor<1xf32>
+  %1 = tosa.pad %arg0, %shape, %pad_const : (tensor<?x?xf32>, !tosa.shape<4>, tensor<1xf32>) -> tensor<?x?xf32>
   return %1 : tensor<?x?xf32>
 }
 
@@ -276,7 +277,8 @@ func.func @pad_noop_type_mismatch_nofold(%arg0: tensor<10xf32>) -> tensor<?xf32>
   // CHECK: %[[PAD:.+]] = tosa.pad
   // CHECK: return %[[PAD]]
   %shape = tosa.const_shape { value = dense<[1, 2]> : tensor<2xindex>} : () -> !tosa.shape<2>
-  %0 = tosa.pad %arg0, %shape : (tensor<10xf32>, !tosa.shape<2>) -> tensor<?xf32>
+  %pad_const = "tosa.const"() {value = dense<3.14> : tensor<1xf32>} : () -> tensor<1xf32>
+  %0 = tosa.pad %arg0, %shape, %pad_const : (tensor<10xf32>, !tosa.shape<2>, tensor<1xf32>) -> tensor<?xf32>
   return %0 : tensor<?xf32>
 }
 
@@ -284,11 +286,12 @@ func.func @pad_noop_type_mismatch_nofold(%arg0: tensor<10xf32>) -> tensor<?xf32>
 
 // CHECK-LABEL: @pad_determine_val_i32
 func.func @pad_determine_val_i32(%arg0: tensor<?x?xi32>, %arg1 : tensor<2x2xi32>) -> tensor<?x?xi32> {
-  // CHECK-DAG: %[[ZERO:.+]] = "tosa.const"() <{value = dense<0> : tensor<i32>}
+  // CHECK-DAG: %[[ZERO:.+]] = "tosa.const"() <{value = dense<3> : tensor<1xi32>}
   // CHECK-DAG: %[[PADDING:.+]] = tosa.const_shape {value = dense<[1, 0, 0, 1]> : tensor<4xindex>} : () -> !tosa.shape<4>
   // CHECK: tosa.pad %arg0, %[[PADDING]], %[[ZERO]]
+  %pad_const = "tosa.const"() {value = dense<3> : tensor<1xi32>} : () -> tensor<1xi32>
   %0 = tosa.const_shape { value = dense<[1, 0, 0, 1]> : tensor<4xindex>} : () -> !tosa.shape<4>
-  %1 = tosa.pad %arg0, %0 : (tensor<?x?xi32>, !tosa.shape<4>) -> tensor<?x?xi32>
+  %1 = tosa.pad %arg0, %0, %pad_const : (tensor<?x?xi32>, !tosa.shape<4>, tensor<1xi32>) -> tensor<?x?xi32>
   return %1 : tensor<?x?xi32>
 }
 
@@ -296,11 +299,12 @@ func.func @pad_determine_val_i32(%arg0: tensor<?x?xi32>, %arg1 : tensor<2x2xi32>
 
 // CHECK-LABEL: @pad_determine_val_f32
 func.func @pad_determine_val_f32(%arg0: tensor<?x?xf32>, %arg1 : tensor<2x2xi32>) -> tensor<?x?xf32> {
-  // CHECK-DAG: %[[ZERO:.+]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<f32>}
+  // CHECK-DAG: %[[ZERO:.+]] = "tosa.const"() <{value = dense<3.140000e+00> : tensor<1xf32>}
   // CHECK-DAG: %[[PADDING:.+]] = tosa.const_shape {value = dense<[1, 0, 0, 1]> : tensor<4xindex>} : () -> !tosa.shape<4>
   // CHECK: tosa.pad %arg0, %[[PADDING]], %[[ZERO]]
+  %pad_const = "tosa.const"() {value = dense<3.14> : tensor<1xf32>} : () -> tensor<1xf32>
   %0 = tosa.const_shape { value = dense<[1, 0, 0, 1]> : tensor<4xindex>} : () -> !tosa.shape<4>
-  %1 = tosa.pad %arg0, %0 : (tensor<?x?xf32>, !tosa.shape<4>) -> tensor<?x?xf32>
+  %1 = tosa.pad %arg0, %0, %pad_const : (tensor<?x?xf32>, !tosa.shape<4>, tensor<1xf32>) -> tensor<?x?xf32>
   return %1 : tensor<?x?xf32>
 }
 
@@ -308,11 +312,12 @@ func.func @pad_determine_val_f32(%arg0: tensor<?x?xf32>, %arg1 : tensor<2x2xi32>
 
 // CHECK-LABEL: @pad_determine_val_quant
 func.func @pad_determine_val_quant(%arg0: tensor<?x?xi32>, %arg1 : tensor<2x2xi32>) -> tensor<?x?xi32> {
-  // CHECK-DAG: %[[ZERO:.+]] = "tosa.const"() <{value = dense<42> : tensor<i32>}
+  // CHECK-DAG: %[[ZERO:.+]] = "tosa.const"() <{value = dense<3> : tensor<1xi32>}
   // CHECK-DAG: %[[PADDING:.+]] = tosa.const_shape {value = dense<[1, 0, 0, 1]> : tensor<4xindex>} : () -> !tosa.shape<4>
   // CHECK: tosa.pad %arg0, %[[PADDING]], %[[ZERO]]
+  %pad_const = "tosa.const"() {value = dense<3> : tensor<1xi32>} : () -> tensor<1xi32>
   %0 = tosa.const_shape { value = dense<[1, 0, 0, 1]> : tensor<4xindex>} : () -> !tosa.shape<4>
-  %1 = tosa.pad %arg0, %0 {input_zp = 42 : i32} : (tensor<?x?xi32>, !tosa.shape<4>) -> tensor<?x?xi32>
+  %1 = tosa.pad %arg0, %0, %pad_const {input_zp = 42 : i32} : (tensor<?x?xi32>, !tosa.shape<4>, tensor<1xi32>) -> tensor<?x?xi32>
   return %1 : tensor<?x?xi32>
 }
 
diff --git a/mlir/test/Dialect/Tosa/constant-op-fold.mlir b/mlir/test/Dialect/Tosa/constant-op-fold.mlir
index e6fb741df9598..013ac3c547b37 100644
--- a/mlir/test/Dialect/Tosa/constant-op-fold.mlir
+++ b/mlir/test/Dialect/Tosa/constant-op-fold.mlir
@@ -136,6 +136,17 @@ func.func @transpose_nofold_dense_resource() -> tensor<2x2xf32> {
 
 // -----
 
+// CHECK-LABEL: @transpose_nofold_quantized_types
+func.func @transpose_nofold_quantized_types() -> tensor<1x1x2x2x!quant.uniform<i8<-127:127>:f32:3, {1.000000e-01,1.000000e-01}>> {
+  %perms = "tosa.const"() {value = dense<[1, 2, 3, 0]> : tensor<4xi32>} : () -> tensor<4xi32>
+  %input = "tosa.const"() {value = dense<-127> : tensor<2x1x1x2xi8>} : () -> tensor<2x1x1x2x!quant.uniform<i8<-127:127>:f32:3, {1.000000e-01,1.000000e-01}>>
+  // CHECK: tosa.transpose
+  %0 = tosa.transpose %input, %perms : (tensor<2x1x1x2x!quant.uniform<i8<-127:127>:f32:3, {1.000000e-01,1.000000e-01}>>, tensor<4xi32>) -> tensor<1x1x2x2x!quant.uniform<i8<-127:127>:f32:3, {1.000000e-01,1.000000e-01}>>
+  return %0: tensor<1x1x2x2x!quant.uniform<i8<-127:127>:f32:3, {1.000000e-01,1.000000e-01}>>
+}
+
+// -----
+
 // CHECK-LABEL: @fold_add_zero_rhs_f32
 func.func @fold_add_zero_rhs_f32(%arg0: tensor<f32>) -> tensor<f32> {
   %zero = "tosa.const"() {value = dense<0.0> : tensor<f32>} : () -> tensor<f32>
diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir
index f35c37a1ef70f..7698e28d68e60 100644
--- a/mlir/test/Dialect/Tosa/invalid.mlir
+++ b/mlir/test/Dialect/Tosa/invalid.mlir
@@ -44,7 +44,7 @@ func.func @test_conv2d(%arg0: tensor<*xi8>, %arg1: tensor<16x3x3x4xi8>, %arg2: t
 
 func.func @test_conv2d(%arg0: tensor<1x29x29x4xi8>, %arg1: tensor<*xi8>, %arg2: tensor<16xi8>) -> tensor<1x27x27x16xi8> {
   %zp = "tosa.const"() {value = dense<0> : tensor<1xi8>} : () -> tensor<1xi8>
-  // expected-error at +1 {{'tosa.conv2d' op operand #1 must be 4D tensor of 4-bit signless integer or 8-bit signless integer or Quint8 type or Qint4 type or Qint8 type or Qint16 type or Qint32 type or floating-point values, but got 'tensor<*xi8>'}}
+  // expected-error at +1 {{'tosa.conv2d' op operand #1 must be 4D tensor of 4-bit signless integer or 8-bit signless integer or Quint8 type or Qint4 type or Qint8 type or Qint16 type or Qint32 type or floating-point or f8E4M3FN type or f8E5M2 type values, but got 'tensor<*xi8>'}}
   %0 = tosa.conv2d %arg0, %arg1, %arg2, %zp, %zp {acc_type = i32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>}
            : (tensor<1x29x29x4xi8>, tensor<*xi8>, tensor<16xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x27x27x16xi8>
   return %0 : tensor<1x27x27x16xi8>
@@ -174,8 +174,7 @@ func.func @test_concat(%arg0 : tensor<2x1xf32>, %arg1 : tensor<2x2xf32>) -> tens
 // -----
 
 func.func @test_concat_element_type_mismatch(%arg0 : tensor<1x2xf32>, %arg1 : tensor<2x2xf32>) -> tensor<?x?xi8> {
-  // expected-error at +2 {{failed to infer returned types}}
-  // expected-error at +1 {{'tosa.concat' op inferred type(s) 'tensor<3x2xf32>' are incompatible with return type(s) of operation 'tensor<?x?xi8>}}
+  // expected-error at +1 {{'tosa.concat' op expect input and output to have same element type, got 'f32' and 'i8'}}
   %0 = tosa.concat %arg0, %arg1 {axis = 0 : i32} : (tensor<1x2xf32>, tensor<2x2xf32>) -> tensor<?x?xi8>
   return %0 : tensor<?x?xi8>
 }
@@ -190,10 +189,10 @@ func.func @test_pad_non_const(%arg0: tensor<13x21x3xf32>, %arg1: !tosa.shape<6>)
 
 // -----
 
-func.func @test_pad_non_const(%arg0: tensor<13x21x3xi8>, %arg1: tensor<i8>) -> tensor<13x21x3xi8> {
+func.func @test_pad_non_const(%arg0: tensor<13x21x3xi8>, %arg1: tensor<1xi8>) -> tensor<13x21x3xi8> {
   %0 = tosa.const_shape {value = dense<[0, 0, 0, 1, 0, 1]> : tensor<6xindex>} : () -> !tosa.shape<6>
   // expected-error at +1 {{'tosa.pad' op pad_const of pad is not constant}}
-  %1 = tosa.pad %arg0, %0, %arg1 : (tensor<13x21x3xi8>, !tosa.shape<6>, tensor<i8>) -> tensor<13x21x3xi8>
+  %1 = tosa.pad %arg0, %0, %arg1 : (tensor<13x21x3xi8>, !tosa.shape<6>, tensor<1xi8>) -> tensor<13x21x3xi8>
   return %1 : tensor<13x21x3xi8>
 }
 
@@ -220,7 +219,7 @@ func.func @test_pad_invalid_padding_rank(%arg0: tensor<13x21xf32>) {
 func.func @test_pad_invalid_padConst_rank(%arg0: tensor<13x21xf32>, %arg1: tensor<2x2xi32>) {
   %0 = tosa.const_shape {value = dense<1> : tensor<4xindex>} : () -> !tosa.shape<4>
   %1 = "tosa.const"() {value = dense<3.14> : tensor<2xf32>} : () -> tensor<2xf32>
-  // expected-error at +1 {{'tosa.pad' op operand #2 must be 0D tensor of number values, but got 'tensor<2xf32>'}}
+  // expected-error at +1 {{'tosa.pad' op operand #2 must be tosa-conformant scalar tensor of number_extended values, but got 'tensor<2xf32>'}}
   %2 = tosa.pad %arg0, %0, %1 : (tensor<13x21xf32>, !tosa.shape<4>, tensor<2xf32>) -> tensor<13x21xf32>
   return
 }
@@ -425,8 +424,7 @@ func.func @test_reduce_min_invalid_output_rank(%arg0 : tensor<i32>) -> () {
 
 func.func @test_reshape_type_mismatch(%arg0 : tensor<13x21x3xf32>) -> () {
   %1 = tosa.const_shape {value = dense<[13, 21, 3, 1]> : tensor<4xindex>} : () -> !tosa.shape<4>
-  // expected-error at +2 {{failed to infer returned types}}
-  // expected-error at +1 {{'tosa.reshape' op inferred type(s) 'tensor<13x21x3x1xf32>' are incompatible with return type(s) of operation 'tensor<13x21x3x1xi32>'}}
+  // expected-error at +1 {{'tosa.reshape' op expect input and output to have same element type, got 'f32' and 'i32'}}
   %0 = tosa.reshape %arg0, %1 : (tensor<13x21x3xf32>, !tosa.shape<4>) -> tensor<13x21x3x1xi32>
   return
 }
@@ -435,7 +433,7 @@ func.func @test_reshape_type_mismatch(%arg0 : tensor<13x21x3xf32>) -> () {
 
 func.func @test_reshape_static_zero_dim_input(%arg0 : tensor<13x0x3xf32>) -> () {
   %s = tosa.const_shape {value = dense<[13, 21, 3]> : tensor<3xindex>} : () -> !tosa.shape<3>
-  // expected-error at +1 {{'tosa.reshape' op operand #0 must be tosa-conformant tensor of number values, but got 'tensor<13x0x3xf32>'}}
+  // expected-error at +1 {{'tosa.reshape' op operand #0 must be tosa-conformant tensor of number_extended values, but got 'tensor<13x0x3xf32>'}}
   %0 = "tosa.reshape"(%arg0, %s) : (tensor<13x0x3xf32>, !tosa.shape<3>) -> tensor<13x0x3xf32>
   return
 }
@@ -444,7 +442,7 @@ func.func @test_reshape_static_zero_dim_input(%arg0 : tensor<13x0x3xf32>) -> ()
 
 func.func @test_reshape_zero_dim_input(%arg0 : tensor<?x0x3xf32>) -> () {
   %s = tosa.const_shape {value = dense<[13, 21, 3]> : tensor<3xindex>} : () -> !tosa.shape<3>
-  // expected-error at +1 {{'tosa.reshape' op operand #0 must be tosa-conformant tensor of number values, but got 'tensor<?x0x3xf32>'}}
+  // expected-error at +1 {{'tosa.reshape' op operand #0 must be tosa-conformant tensor of number_extended values, but got 'tensor<?x0x3xf32>'}}
   %0 = "tosa.reshape"(%arg0, %s) : (tensor<?x0x3xf32>, !tosa.shape<3>) -> tensor<13x0x3xf32>
   return
 }
@@ -516,7 +514,7 @@ func.func @test_reshape_invalid_tensor_dim(%arg0 : tensor<4x?xf32>) -> () {
 
 func.func @test_reverse_axis_out_of_range(%arg0 : tensor<13x21x3xf32>) -> () {
   // expected-error at +1 {{'tosa.reverse' op expect input tensor rank (3) to be larger than reverse axis (5)}}
-  %0 = tosa.reverse %arg0 {axis = 5 : i32} : (tensor<13x21x3xf32>) -> tensor<?x?x?xi32>
+  %0 = tosa.reverse %arg0 {axis = 5 : i32} : (tensor<13x21x3xf32>) -> tensor<?x?x?xf32>
   return
 }
 
@@ -524,7 +522,7 @@ func.func @test_reverse_axis_out_of_range(%arg0 : tensor<13x21x3xf32>) -> () {
 
 func.func @test_reshape_zero_dim_input(%arg0 : tensor<?x0x3xf32>) -> () {
   %1 = tosa.const_shape {value = dense<[13, 21, 3]> : tensor<3xindex>} : () -> !tosa.shape<3>
-  // expected-error at +1 {{'tosa.reshape' op operand #0 must be tosa-conformant tensor of number values, but got 'tensor<?x0x3xf32>'}}
+  // expected-error at +1 {{'tosa.reshape' op operand #0 must be tosa-conformant tensor of number_extended values, but got 'tensor<?x0x3xf32>'}}
   %0 = "tosa.reshape"(%arg0, %1) : (tensor<?x0x3xf32>, !tosa.shape<3>) -> tensor<13x0x3xf32>
   return
 }
@@ -540,7 +538,7 @@ func.func @test_const_attribute_type_mismatch() -> tensor<100x100xf32> {
 // -----
 
 func.func @test_conv2d_static_zero_dim_input(%arg0: tensor<1x29x0x4xf32>, %arg1: tensor<16x3x3x4xf32>, %arg2: tensor<16xf32>) -> tensor<1x27x27x16xf32> {
-  // expected-error at +1 {{'tosa.conv2d' op operand #0 must be 4-d tosa-conformant tensor, but got 'tensor<1x29x0x4xf32>'}}
+  // expected-error at +1 {{'tosa.conv2d' op operand #0 must be 4-d tosa-conformant tensor extended, but got 'tensor<1x29x0x4xf32>'}}
   %0 = "tosa.conv2d"(%arg0, %arg1, %arg2) {acc_type = f32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>}
            : (tensor<1x29x0x4xf32>, tensor<16x3x3x4xf32>, tensor<16xf32>) -> tensor<1x27x27x16xf32>
   return %0 : tensor<1x27x27x16xf32>
@@ -549,7 +547,7 @@ func.func @test_conv2d_static_zero_dim_input(%arg0: tensor<1x29x0x4xf32>, %arg1:
 // -----
 
 func.func @test_conv2d_zero_dim_input(%arg0: tensor<1x?x0x4xf32>, %arg1: tensor<16x3x3x4xf32>, %arg2: tensor<16xf32>) -> tensor<1x27x27x16xf32> {
-  // expected-error at +1 {{'tosa.conv2d' op operand #0 must be 4-d tosa-conformant tensor, but got 'tensor<1x?x0x4xf32>'}}
+  // expected-error at +1 {{'tosa.conv2d' op operand #0 must be 4-d tosa-conformant tensor extended, but got 'tensor<1x?x0x4xf32>'}}
   %0 = "tosa.conv2d"(%arg0, %arg1, %arg2) {acc_type = f32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>}
            : (tensor<1x?x0x4xf32>, tensor<16x3x3x4xf32>, tensor<16xf32>) -> tensor<1x27x27x16xf32>
   return %0 : tensor<1x27x27x16xf32>
@@ -559,7 +557,7 @@ func.func @test_conv2d_zero_dim_input(%arg0: tensor<1x?x0x4xf32>, %arg1: tensor<
 // -----
 
 func.func @test_avg_pool2d_static_zero_dim_input(%arg0: tensor<1x0x7x9xf32>) -> tensor<1x7x7x9xf32> {
-  // expected-error at +1 {{'tosa.avg_pool2d' op operand #0 must be 4-d tosa-conformant tensor, but got 'tensor<1x0x7x9xf32>'}}
+  // expected-error at +1 {{'tosa.avg_pool2d' op operand #0 must be 4-d tosa-conformant tensor extended, but got 'tensor<1x0x7x9xf32>'}}
     %0 = "tosa.avg_pool2d"(%arg0) {acc_type = f32, kernel = array<i64: 2, 2>, pad = array<i64: 0, 1, 0, 1>, stride = array<i64: 1, 1>}
       : (tensor<1x0x7x9xf32>) -> tensor<1x7x7x9xf32>
     return %0 : tensor<1x7x7x9xf32>
@@ -568,7 +566,7 @@ func.func @test_avg_pool2d_static_zero_dim_input(%arg0: tensor<1x0x7x9xf32>) ->
 // -----
 
 func.func @test_avg_pool2d_zero_dim_input(%arg0: tensor<1x0x?x9xf32>) -> tensor<1x7x7x9xf32> {
-  // expected-error at +1 {{'tosa.avg_pool2d' op operand #0 must be 4-d tosa-conformant tensor, but got 'tensor<1x0x?x9xf32>'}}
+  // expected-error at +1 {{'tosa.avg_pool2d' op operand #0 must be 4-d tosa-conformant tensor extended, but got 'tensor<1x0x?x9xf32>'}}
     %0 = "tosa.avg_pool2d"(%arg0) {acc_type = f32, kernel = array<i64: 2, 2>, pad = array<i64: 0, 1, 0, 1>, stride = array<i64: 1, 1>}
       : (tensor<1x0x?x9xf32>) -> tensor<1x7x7x9xf32>
     return %0 : tensor<1x7x7x9xf32>
@@ -625,7 +623,7 @@ func.func @test_slice_invalid_start() {
   %0 = tensor.empty() : tensor<4x31x31xf32>
   %start = tosa.const_shape {value = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
   %size = tosa.const_shape {value = dense<[1, 1, 1]> : tensor<3xindex>} : () -> !tosa.shape<3>
-  // expected-error at +1 {{'tosa.slice' op length of start attribute is not equal rank of input shape}}
+  // expected-error at +1 {{'tosa.slice' op length of start is not equal to rank of input shape}}
   %3 = tosa.slice %0, %start, %size : (tensor<4x31x31xf32>, !tosa.shape<2>, !tosa.shape<3>) -> tensor<*xf32>
   return
 }
@@ -636,7 +634,7 @@ func.func @test_slice_invalid_size() {
   %0 = tensor.empty() : tensor<4x31x31xf32>
   %start = tosa.const_shape {value = dense<[1, 1, 1]> : tensor<3xindex>} : () -> !tosa.shape<3>
   %size = tosa.const_shape {value = dense<[1]> : tensor<1xindex>} : () -> !tosa.shape<1>
-  // expected-error at +1 {{'tosa.slice' op length of size attribute is not equal rank of input shape}}
+  // expected-error at +1 {{'tosa.slice' op length of size is not equal to rank of input shape}}
   %3 = tosa.slice %0, %start, %size : (tensor<4x31x31xf32>, !tosa.shape<3>, !tosa.shape<1>) -> tensor<*xf32>
   return
 }
diff --git a/mlir/test/Dialect/Tosa/ops.mlir b/mlir/test/Dialect/Tosa/ops.mlir
index faac8b7c1ff93..85b6b27c7e849 100644
--- a/mlir/test/Dialect/Tosa/ops.mlir
+++ b/mlir/test/Dialect/Tosa/ops.mlir
@@ -582,9 +582,9 @@ func.func @test_pad(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
 // -----
 // CHECK-LABEL: pad_explicit_value
 func.func @test_pad_explicit_value(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
-  %0 = "tosa.const"() {value = dense<3.14> : tensor<f32>} : () -> tensor<f32>
+  %0 = "tosa.const"() {value = dense<3.14> : tensor<1xf32>} : () -> tensor<1xf32>
   %padding = tosa.const_shape {value = dense<0> : tensor<6xindex>} : () -> !tosa.shape<6>
-  %1 = tosa.pad %arg0, %padding, %0 : (tensor<13x21x3xf32>, !tosa.shape<6>, tensor<f32>) -> tensor<13x21x3xf32>
+  %1 = tosa.pad %arg0, %padding, %0 : (tensor<13x21x3xf32>, !tosa.shape<6>, tensor<1xf32>) -> tensor<13x21x3xf32>
   return %1 : tensor<13x21x3xf32>
 }
 
@@ -763,3 +763,285 @@ func.func @test_const_shape() -> !tosa.shape<4> {
   %cst = tosa.const_shape {value = dense<1> : tensor<4xindex>} : () -> !tosa.shape<4>
   return %cst : !tosa.shape<4>
 }
+
+// F8 support tests
+
+// -----
+// CHECK-LABEL: argmax_f8E5M2
+func.func @test_argmax_f8E5M2(%arg0: tensor<12x8x16xf8E5M2>) -> tensor<12x16xi32> {
+  %0 = tosa.argmax %arg0 { axis = 1 : i32 } : (tensor<12x8x16xf8E5M2>) -> tensor<12x16xi32>
+  return %0 : tensor<12x16xi32>
+}
+
+// -----
+// CHECK-LABEL: avg_pool2d_f8E5M2
+func.func @test_avg_pool2d_f8E5M2(%arg0: tensor<1x7x7x9xf8E5M2>) -> tensor<1x7x7x9xf8E5M2> {
+  %0 = tosa.avg_pool2d %arg0 {acc_type = f16, kernel = array<i64: 2, 2>, pad = array<i64: 0, 1, 0, 1>, stride = array<i64: 1, 1>} : (tensor<1x7x7x9xf8E5M2>) -> tensor<1x7x7x9xf8E5M2>
+  return %0 : tensor<1x7x7x9xf8E5M2>
+}
+
+// -----
+// CHECK-LABEL: conv2d_f8E5M2
+func.func @test_conv2d_f8E5M2(%arg0: tensor<1x4x4x4xf8E5M2>, %arg1: tensor<8x1x1x4xf8E5M2>, %arg2: tensor<8xf16>) -> tensor<1x4x4x8xf16> {
+  %0 = tosa.conv2d %arg0, %arg1, %arg2 {acc_type = f16, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, local_bound = true} : (tensor<1x4x4x4xf8E5M2>, tensor<8x1x1x4xf8E5M2>, tensor<8xf16>) -> tensor<1x4x4x8xf16>
+  return %0 : tensor<1x4x4x8xf16>
+}
+
+// -----
+// CHECK-LABEL: conv3d_f8E5M2
+func.func @test_conv3d_f8E5M2(%arg0: tensor<1x4x8x21x17xf8E5M2>, %arg1: tensor<34x1x1x1x17xf8E5M2>, %arg2: tensor<34xf16>) -> tensor<1x4x8x21x34xf16> {
+  %0 = tosa.conv3d %arg0, %arg1, %arg2 {acc_type = f16, dilation = array<i64: 1, 1, 1>, pad = array<i64: 0, 0, 0, 0, 0, 0>, stride = array<i64: 1, 1, 1>} : (tensor<1x4x8x21x17xf8E5M2>, tensor<34x1x1x1x17xf8E5M2>, tensor<34xf16>) -> tensor<1x4x8x21x34xf16>
+  return %0 : tensor<1x4x8x21x34xf16>
+}
+
+// -----
+// CHECK-LABEL: depthwise_conv2d_f8E5M2
+func.func @test_depthwise_conv2d_f8E5M2(%arg0: tensor<1x4x4x4xf8E5M2>, %arg1: tensor<1x1x4x2xf8E5M2>, %arg2: tensor<8xf16>) -> tensor<1x4x4x8xf16> {
+  %0 = tosa.depthwise_conv2d %arg0, %arg1, %arg2 {acc_type = f16, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<1x4x4x4xf8E5M2>, tensor<1x1x4x2xf8E5M2>, tensor<8xf16>) -> tensor<1x4x4x8xf16>
+  return %0 : tensor<1x4x4x8xf16>
+}
+
+// -----
+// CHECK-LABEL: matmul_f8E5M2
+func.func @test_matmul_f8E5M2(%arg0: tensor<1x14x19xf8E5M2>, %arg1: tensor<1x19x28xf8E5M2>) -> tensor<1x14x28xf16> {
+  %0 = tosa.matmul %arg0, %arg1 : (tensor<1x14x19xf8E5M2>, tensor<1x19x28xf8E5M2>) -> tensor<1x14x28xf16>
+  return %0 : tensor<1x14x28xf16>
+}
+
+// -----
+// CHECK-LABEL: max_pool2d_f8E5M2
+func.func @test_max_pool2d_f8E5M2(%arg0: tensor<1x32x32x8xf8E5M2>) -> tensor<1x32x32x8xf8E5M2> {
+  %0 = tosa.max_pool2d %arg0 {kernel = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<1x32x32x8xf8E5M2>) -> tensor<1x32x32x8xf8E5M2>
+  return %0 : tensor<1x32x32x8xf8E5M2>
+}
+
+// -----
+// CHECK-LABEL: transpose_conv2d_f8E5M2
+func.func @test_transpose_conv2d_f8E5M2(%arg0: tensor<1x32x32x8xf8E5M2>, %arg1: tensor<16x1x1x8xf8E5M2>, %arg2: tensor<16xf16>) -> tensor<1x32x32x16xf16> {
+  %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2 {acc_type = f16, out_pad = array<i64: 0, 0, 0, 0>, out_shape = array<i64: 1, 32, 32, 16>, stride = array<i64: 1, 1>} : (tensor<1x32x32x8xf8E5M2>, tensor<16x1x1x8xf8E5M2>, tensor<16xf16>) -> tensor<1x32x32x16xf16>
+  return %0 : tensor<1x32x32x16xf16>
+}
+
+// -----
+// CHECK-LABEL: const_f8E5M2
+func.func @test_const_f8E5M2(%arg0 : index) -> tensor<4xf8E5M2> {
+    %0 = "tosa.const"() {value = dense<[3.0, -0.0, -1.0, 2.0]> : tensor<4xf8E5M2>} : () -> tensor<4xf8E5M2>
+    return %0 : tensor<4xf8E5M2>
+}
+
+// -----
+// CHECK-LABEL: cast_f8E5M2
+func.func @test_cast_f8E5M2(%arg0: tensor<13x21x3xf8E5M2>) -> tensor<13x21x3xf16> {
+  %0 = tosa.cast %arg0 : (tensor<13x21x3xf8E5M2>) -> tensor<13x21x3xf16>
+  return %0 : tensor<13x21x3xf16>
+}
+
+// -----
+// CHECK-LABEL: concat_f8E5M2
+func.func @test_concat_f8E5M2(%arg0: tensor<13x21x3xf8E5M2>, %arg1: tensor<13x21x3xf8E5M2>) -> tensor<26x21x3xf8E5M2> {
+  %0 = tosa.concat %arg0, %arg1 {axis = 0 : i32} : (tensor<13x21x3xf8E5M2>, tensor<13x21x3xf8E5M2>) -> tensor<26x21x3xf8E5M2>
+  return %0 : tensor<26x21x3xf8E5M2>
+}
+
+// -----
+// CHECK-LABEL: pad_f8E5M2
+func.func @test_pad_f8E5M2(%arg0: tensor<13x21x3xf8E5M2>) -> tensor<13x21x3xf8E5M2> {
+  %padding = tosa.const_shape {value = dense<0> : tensor<6xindex>} : () -> !tosa.shape<6>
+  %cst = "tosa.const"() { value = dense<-0.0> : tensor<1xf8E5M2> } : () -> tensor<1xf8E5M2>
+  %0 = tosa.pad %arg0, %padding, %cst : (tensor<13x21x3xf8E5M2>, !tosa.shape<6>, tensor<1xf8E5M2>) -> tensor<13x21x3xf8E5M2>
+  return %0 : tensor<13x21x3xf8E5M2>
+}
+
+// -----
+// CHECK-LABEL: reshape_f8E5M2
+func.func @test_reshape_f8E5M2(%arg0: tensor<13x21x3xf8E5M2>) -> tensor<1x819xf8E5M2> {
+  %1 = tosa.const_shape {value = dense<[1, 819]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  %0 = tosa.reshape %arg0, %1 : (tensor<13x21x3xf8E5M2>, !tosa.shape<2>) -> tensor<1x819xf8E5M2>
+  return %0 : tensor<1x819xf8E5M2>
+}
+
+// -----
+// CHECK-LABEL: reverse_f8E5M2
+func.func @test_reverse_f8E5M2(%arg0: tensor<13x21x3xf8E5M2>) -> tensor<13x21x3xf8E5M2> {
+  %0 = tosa.reverse %arg0 {axis = 0 : i32} : (tensor<13x21x3xf8E5M2>) -> tensor<13x21x3xf8E5M2>
+  return %0 : tensor<13x21x3xf8E5M2>
+}
+
+// -----
+// CHECK-LABEL: slice_f8E5M2
+func.func @test_slice_f8E5M2(%arg0: tensor<13x21x3xf8E5M2>) -> tensor<4x11x1xf8E5M2> {
+  %start = tosa.const_shape {value = dense<[6, 8, 0]> : tensor<3xindex>} : () -> !tosa.shape<3>
+  %size = tosa.const_shape {value = dense<[4, 11, 1]> : tensor<3xindex>} : () -> !tosa.shape<3>
+  %0 = tosa.slice %arg0, %start, %size : (tensor<13x21x3xf8E5M2>, !tosa.shape<3>, !tosa.shape<3>) -> tensor<4x11x1xf8E5M2>
+  return %0 : tensor<4x11x1xf8E5M2>
+}
+
+// -----
+// CHECK-LABEL: tile_f8E5M2
+func.func @test_tile_f8E5M2(%arg0: tensor<13x21x3xf8E5M2>) -> tensor<39x21x6xf8E5M2> {
+  %cst = tosa.const_shape { value = dense<[3, 1, 2]> : tensor<3xindex> } : () -> !tosa.shape<3>
+  %0 = tosa.tile %arg0, %cst: (tensor<13x21x3xf8E5M2>, !tosa.shape<3>) -> tensor<39x21x6xf8E5M2>
+  return %0 : tensor<39x21x6xf8E5M2>
+}
+
+// -----
+// CHECK-LABEL: transpose_f8E5M2
+func.func @test_transpose_f8E5M2(%arg0: tensor<13x21x3xf8E5M2>) -> tensor<3x13x21xf8E5M2> {
+  %0 = "tosa.const"() {value = dense<[2, 0, 1]> : tensor<3xi32>} : () -> tensor<3xi32>
+  %1 = tosa.transpose %arg0, %0 : (tensor<13x21x3xf8E5M2>, tensor<3xi32>) -> tensor<3x13x21xf8E5M2>
+  return %1 : tensor<3x13x21xf8E5M2>
+}
+
+// -----
+// CHECK-LABEL: gather_f8E5M2
+func.func @test_gather_f8E5M2(%arg0: tensor<13x21x3xf8E5M2>, %arg1: tensor<13x26xi32>) -> tensor<13x26x3xf8E5M2> {
+  %0 = tosa.gather %arg0, %arg1 : (tensor<13x21x3xf8E5M2>, tensor<13x26xi32>) -> tensor<13x26x3xf8E5M2>
+  return %0 : tensor<13x26x3xf8E5M2>
+}
+
+// -----
+// CHECK-LABEL: scatter_f8E5M2
+func.func @test_scatter_f8E5M2(%arg0: tensor<13x21x3xf8E5M2>, %arg1: tensor<13x26xi32>, %arg2: tensor<13x26x3xf8E5M2>) -> tensor<13x21x3xf8E5M2> {
+  %0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x21x3xf8E5M2>, tensor<13x26xi32>, tensor<13x26x3xf8E5M2>) -> tensor<13x21x3xf8E5M2>
+  return %0 : tensor<13x21x3xf8E5M2>
+}
+
+// -----
+// CHECK-LABEL: argmax_f8E4M3FN
+func.func @test_argmax_f8E4M3FN(%arg0: tensor<12x8x16xf8E4M3FN>) -> tensor<12x16xi32> {
+  %0 = tosa.argmax %arg0 { axis = 1 : i32 } : (tensor<12x8x16xf8E4M3FN>) -> tensor<12x16xi32>
+  return %0 : tensor<12x16xi32>
+}
+
+// -----
+// CHECK-LABEL: avg_pool2d_f8E4M3FN
+func.func @test_avg_pool2d_f8E4M3FN(%arg0: tensor<1x7x7x9xf8E4M3FN>) -> tensor<1x7x7x9xf8E4M3FN> {
+  %0 = tosa.avg_pool2d %arg0 {acc_type = f16, kernel = array<i64: 2, 2>, pad = array<i64: 0, 1, 0, 1>, stride = array<i64: 1, 1>} : (tensor<1x7x7x9xf8E4M3FN>) -> tensor<1x7x7x9xf8E4M3FN>
+  return %0 : tensor<1x7x7x9xf8E4M3FN>
+}
+
+// -----
+// CHECK-LABEL: conv2d_f8E4M3FN
+func.func @test_conv2d_f8E4M3FN(%arg0: tensor<1x4x4x4xf8E4M3FN>, %arg1: tensor<8x1x1x4xf8E4M3FN>, %arg2: tensor<8xf16>) -> tensor<1x4x4x8xf16> {
+  %0 = tosa.conv2d %arg0, %arg1, %arg2 {acc_type = f16, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, local_bound = true} : (tensor<1x4x4x4xf8E4M3FN>, tensor<8x1x1x4xf8E4M3FN>, tensor<8xf16>) -> tensor<1x4x4x8xf16>
+  return %0 : tensor<1x4x4x8xf16>
+}
+
+// -----
+// CHECK-LABEL: conv3d_f8E4M3FN
+func.func @test_conv3d_f8E4M3FN(%arg0: tensor<1x4x8x21x17xf8E4M3FN>, %arg1: tensor<34x1x1x1x17xf8E4M3FN>, %arg2: tensor<34xf16>) -> tensor<1x4x8x21x34xf16> {
+  %0 = tosa.conv3d %arg0, %arg1, %arg2 {acc_type = f16, dilation = array<i64: 1, 1, 1>, pad = array<i64: 0, 0, 0, 0, 0, 0>, stride = array<i64: 1, 1, 1>} : (tensor<1x4x8x21x17xf8E4M3FN>, tensor<34x1x1x1x17xf8E4M3FN>, tensor<34xf16>) -> tensor<1x4x8x21x34xf16>
+  return %0 : tensor<1x4x8x21x34xf16>
+}
+
+// -----
+// CHECK-LABEL: depthwise_conv2d_f8E4M3FN
+func.func @test_depthwise_conv2d_f8E4M3FN(%arg0: tensor<1x4x4x4xf8E4M3FN>, %arg1: tensor<1x1x4x2xf8E4M3FN>, %arg2: tensor<8xf16>) -> tensor<1x4x4x8xf16> {
+  %0 = tosa.depthwise_conv2d %arg0, %arg1, %arg2 {acc_type = f16, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<1x4x4x4xf8E4M3FN>, tensor<1x1x4x2xf8E4M3FN>, tensor<8xf16>) -> tensor<1x4x4x8xf16>
+  return %0 : tensor<1x4x4x8xf16>
+}
+
+// -----
+// CHECK-LABEL: matmul_f8E4M3FN
+func.func @test_matmul_f8E4M3FN(%arg0: tensor<1x14x19xf8E4M3FN>, %arg1: tensor<1x19x28xf8E4M3FN>) -> tensor<1x14x28xf16> {
+  %0 = tosa.matmul %arg0, %arg1 : (tensor<1x14x19xf8E4M3FN>, tensor<1x19x28xf8E4M3FN>) -> tensor<1x14x28xf16>
+  return %0 : tensor<1x14x28xf16>
+}
+
+// -----
+// CHECK-LABEL: max_pool2d_f8E4M3FN
+func.func @test_max_pool2d_f8E4M3FN(%arg0: tensor<1x32x32x8xf8E4M3FN>) -> tensor<1x32x32x8xf8E4M3FN> {
+  %0 = tosa.max_pool2d %arg0 {kernel = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<1x32x32x8xf8E4M3FN>) -> tensor<1x32x32x8xf8E4M3FN>
+  return %0 : tensor<1x32x32x8xf8E4M3FN>
+}
+
+// -----
+// CHECK-LABEL: transpose_conv2d_f8E4M3FN
+func.func @test_transpose_conv2d_f8E4M3FN(%arg0: tensor<1x32x32x8xf8E4M3FN>, %arg1: tensor<16x1x1x8xf8E4M3FN>, %arg2: tensor<16xf16>) -> tensor<1x32x32x16xf16> {
+  %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2 {acc_type = f16, out_pad = array<i64: 0, 0, 0, 0>, out_shape = array<i64: 1, 32, 32, 16>, stride = array<i64: 1, 1>} : (tensor<1x32x32x8xf8E4M3FN>, tensor<16x1x1x8xf8E4M3FN>, tensor<16xf16>) -> tensor<1x32x32x16xf16>
+  return %0 : tensor<1x32x32x16xf16>
+}
+
+// -----
+// CHECK-LABEL: const_f8E4M3FN
+func.func @test_const_f8E4M3FN(%arg0 : index) -> tensor<4xf8E4M3FN> {
+    %0 = "tosa.const"() {value = dense<[3.0, -0.0, -1.0, 2.0]> : tensor<4xf8E4M3FN>} : () -> tensor<4xf8E4M3FN>
+    return %0 : tensor<4xf8E4M3FN>
+}
+
+// -----
+// CHECK-LABEL: cast_f8E4M3FN
+func.func @test_cast_f8E4M3FN(%arg0: tensor<13x21x3xf8E4M3FN>) -> tensor<13x21x3xf16> {
+  %0 = tosa.cast %arg0 : (tensor<13x21x3xf8E4M3FN>) -> tensor<13x21x3xf16>
+  return %0 : tensor<13x21x3xf16>
+}
+
+// -----
+// CHECK-LABEL: concat_f8E4M3FN
+func.func @test_concat_f8E4M3FN(%arg0: tensor<13x21x3xf8E4M3FN>, %arg1: tensor<13x21x3xf8E4M3FN>) -> tensor<26x21x3xf8E4M3FN> {
+  %0 = tosa.concat %arg0, %arg1 {axis = 0 : i32} : (tensor<13x21x3xf8E4M3FN>, tensor<13x21x3xf8E4M3FN>) -> tensor<26x21x3xf8E4M3FN>
+  return %0 : tensor<26x21x3xf8E4M3FN>
+}
+
+// -----
+// CHECK-LABEL: pad_f8E4M3FN
+func.func @test_pad_f8E4M3FN(%arg0: tensor<13x21x3xf8E4M3FN>) -> tensor<13x21x3xf8E4M3FN> {
+  %padding = tosa.const_shape {value = dense<0> : tensor<6xindex>} : () -> !tosa.shape<6>
+  %cst = "tosa.const"() { value = dense<-0.0> : tensor<1xf8E4M3FN> } : () -> tensor<1xf8E4M3FN>
+  %0 = tosa.pad %arg0, %padding, %cst : (tensor<13x21x3xf8E4M3FN>, !tosa.shape<6>, tensor<1xf8E4M3FN>) -> tensor<13x21x3xf8E4M3FN>
+  return %0 : tensor<13x21x3xf8E4M3FN>
+}
+
+// -----
+// CHECK-LABEL: reshape_f8E4M3FN
+func.func @test_reshape_f8E4M3FN(%arg0: tensor<13x21x3xf8E4M3FN>) -> tensor<1x819xf8E4M3FN> {
+  %1 = tosa.const_shape {value = dense<[1, 819]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  %0 = tosa.reshape %arg0, %1 : (tensor<13x21x3xf8E4M3FN>, !tosa.shape<2>) -> tensor<1x819xf8E4M3FN>
+  return %0 : tensor<1x819xf8E4M3FN>
+}
+
+// -----
+// CHECK-LABEL: reverse_f8E4M3FN
+func.func @test_reverse_f8E4M3FN(%arg0: tensor<13x21x3xf8E4M3FN>) -> tensor<13x21x3xf8E4M3FN> {
+  %0 = tosa.reverse %arg0 {axis = 0 : i32} : (tensor<13x21x3xf8E4M3FN>) -> tensor<13x21x3xf8E4M3FN>
+  return %0 : tensor<13x21x3xf8E4M3FN>
+}
+
+// -----
+// CHECK-LABEL: slice_f8E4M3FN
+func.func @test_slice_f8E4M3FN(%arg0: tensor<13x21x3xf8E4M3FN>) -> tensor<4x11x1xf8E4M3FN> {
+  %start = tosa.const_shape {value = dense<[6, 8, 0]> : tensor<3xindex>} : () -> !tosa.shape<3>
+  %size = tosa.const_shape {value = dense<[4, 11, 1]> : tensor<3xindex>} : () -> !tosa.shape<3>
+  %0 = tosa.slice %arg0, %start, %size : (tensor<13x21x3xf8E4M3FN>, !tosa.shape<3>, !tosa.shape<3>) -> tensor<4x11x1xf8E4M3FN>
+  return %0 : tensor<4x11x1xf8E4M3FN>
+}
+
+// -----
+// CHECK-LABEL: tile_f8E4M3FN
+func.func @test_tile_f8E4M3FN(%arg0: tensor<13x21x3xf8E4M3FN>) -> tensor<39x21x6xf8E4M3FN> {
+  %cst = tosa.const_shape { value = dense<[3, 1, 2]> : tensor<3xindex> } : () -> !tosa.shape<3>
+  %0 = tosa.tile %arg0, %cst: (tensor<13x21x3xf8E4M3FN>, !tosa.shape<3>) -> tensor<39x21x6xf8E4M3FN>
+  return %0 : tensor<39x21x6xf8E4M3FN>
+}
+
+// -----
+// CHECK-LABEL: transpose_f8E4M3FN
+func.func @test_transpose_f8E4M3FN(%arg0: tensor<13x21x3xf8E4M3FN>) -> tensor<3x13x21xf8E4M3FN> {
+  %0 = "tosa.const"() {value = dense<[2, 0, 1]> : tensor<3xi32>} : () -> tensor<3xi32>
+  %1 = tosa.transpose %arg0, %0 : (tensor<13x21x3xf8E4M3FN>, tensor<3xi32>) -> tensor<3x13x21xf8E4M3FN>
+  return %1 : tensor<3x13x21xf8E4M3FN>
+}
+
+// -----
+// CHECK-LABEL: gather_f8E4M3FN
+func.func @test_gather_f8E4M3FN(%arg0: tensor<13x21x3xf8E4M3FN>, %arg1: tensor<13x26xi32>) -> tensor<13x26x3xf8E4M3FN> {
+  %0 = tosa.gather %arg0, %arg1 : (tensor<13x21x3xf8E4M3FN>, tensor<13x26xi32>) -> tensor<13x26x3xf8E4M3FN>
+  return %0 : tensor<13x26x3xf8E4M3FN>
+}
+
+// -----
+// CHECK-LABEL: scatter_f8E4M3FN
+func.func @test_scatter_f8E4M3FN(%arg0: tensor<13x21x3xf8E4M3FN>, %arg1: tensor<13x26xi32>, %arg2: tensor<13x26x3xf8E4M3FN>) -> tensor<13x21x3xf8E4M3FN> {
+  %0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x21x3xf8E4M3FN>, tensor<13x26xi32>, tensor<13x26x3xf8E4M3FN>) -> tensor<13x21x3xf8E4M3FN>
+  return %0 : tensor<13x21x3xf8E4M3FN>
+}
diff --git a/mlir/test/Dialect/Tosa/tosa-decompose-depthwise.mlir b/mlir/test/Dialect/Tosa/tosa-decompose-depthwise.mlir
index 6562a7c2ab55c..3a03f13522db2 100644
--- a/mlir/test/Dialect/Tosa/tosa-decompose-depthwise.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-decompose-depthwise.mlir
@@ -59,13 +59,13 @@ func.func @depthwise_conv2d_as_mul_q(%arg0: tensor<4x10x10x2xi8>, %arg1: tensor<
 func.func @depthwise_conv2d_as_mul_padded(%arg0: tensor<4x10x10x2xf32>, %arg1: tensor<1x1x2x3xf32>, %arg2: tensor<6xf32>) -> tensor<4x12x12x6xf32> {
   // CHECK-DAG: %[[CONST0:.+]] = tosa.const_shape {value = dense<[4, 10, 10, 2, 1]> : tensor<5xindex>}
   // CHECK-DAG: %[[pad:.+]] = tosa.const_shape  {value = dense<[0, 0, 1, 1, 1, 1, 0, 0, 0, 0]> : tensor<10xindex>} : () -> !tosa.shape<10>
-  // CHECK-DAG: %[[zero:.+]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<f32>}
+  // CHECK-DAG: %[[zero:.+]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<1xf32>}
   // CHECK-DAG: %[[CONST3:.+]] = tosa.const_shape {value = dense<[1, 1, 1, 2, 3]> : tensor<5xindex>}
   // CHECK-DAG: %[[CONST4:.+]] = tosa.const_shape {value = dense<[4, 12, 12, 6]> : tensor<4xindex>}
   // CHECK-DAG: %[[CONST5:.+]] = tosa.const_shape {value = dense<[1, 1, 1, 6]> : tensor<4xindex>}
   // CHECK-DAG: %[[SHIFT:.*]] = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
   // CHECK: %[[reIn:.+]] = tosa.reshape %arg0, %[[CONST0]]
-  // CHECK: %[[padded:.+]] = tosa.pad %[[reIn]], %[[pad]], %[[zero]] : (tensor<4x10x10x2x1xf32>, !tosa.shape<10>, tensor<f32>) -> tensor<4x12x12x2x1xf32>
+  // CHECK: %[[padded:.+]] = tosa.pad %[[reIn]], %[[pad]], %[[zero]] : (tensor<4x10x10x2x1xf32>, !tosa.shape<10>, tensor<1xf32>) -> tensor<4x12x12x2x1xf32>
   // CHECK: %[[reArg1:.+]] = tosa.reshape %arg1, %[[CONST3]]
   // CHECK: %[[mul:.+]] = tosa.mul %[[padded]], %[[reArg1]], %[[SHIFT]]
   // CHECK: %[[reOut:.+]] = tosa.reshape %[[mul]], %[[CONST4]]



More information about the Mlir-commits mailing list