[Mlir-commits] [mlir] [mlir][tosa] Add FP8 support (PR #127730)
llvmlistbot at llvm.org
Tue Feb 18 19:50:26 PST 2025
https://github.com/Jerry-Ge updated https://github.com/llvm/llvm-project/pull/127730
From 8f928e66b0cd02b7822f3a7f80073896f16649ac Mon Sep 17 00:00:00 2001
From: Tai Ly <tai.ly at arm.com>
Date: Thu, 1 Feb 2024 23:44:37 +0000
Subject: [PATCH] [mlir][tosa] Add FP8 support
Add FP8 support to the following TOSA operators:
ARGMAX
AVGPOOL
CONV2D
CONV3D
DEPTHWISE_CONV2D
MATMUL
MAX_POOL2D
TRANSPOSE_CONV2D
CONST
CAST
CONCAT
PAD
DIM
RESHAPE
REVERSE
SLICE
TILE
TRANSPOSE
GATHER
SCATTER
Also add verifiers as needed to check input/output element types,
and rename the inputs of transpose_conv2d and select to match the spec.
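For example (illustrative IR, mirroring the new tests in ops.mlir), f8E5M2
matmul operands now accumulate into an f16 result:
  %0 = tosa.matmul %a, %b : (tensor<1x14x19xf8E5M2>, tensor<1x19x28xf8E5M2>) -> tensor<1x14x28xf16>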
Signed-off-by: Tai Ly <tai.ly at arm.com>
Signed-off-by: Jerry Ge <jerry.ge at arm.com>
Change-Id: I56adfabb2396e38b7ed3479e4fd680b740bdb4e4
---
mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td | 181 ++++++----
.../mlir/Dialect/Tosa/IR/TosaTypesBase.td | 31 +-
.../Dialect/Tosa/IR/TosaCanonicalizations.cpp | 14 +-
mlir/lib/Dialect/Tosa/IR/TosaOps.cpp | 339 +++++++++++++++++-
.../Transforms/TosaDecomposeDepthwise.cpp | 2 +-
.../Tosa/Transforms/TosaMakeBroadcastable.cpp | 6 +-
.../TosaToLinalg/tosa-to-linalg-pipeline.mlir | 2 +-
.../TosaToTensor/tosa-to-tensor.mlir | 4 +-
mlir/test/Dialect/Tosa/canonicalize.mlir | 21 +-
mlir/test/Dialect/Tosa/constant-op-fold.mlir | 11 +
mlir/test/Dialect/Tosa/invalid.mlir | 34 +-
mlir/test/Dialect/Tosa/ops.mlir | 286 ++++++++++++++-
.../Tosa/tosa-decompose-depthwise.mlir | 4 +-
13 files changed, 806 insertions(+), 129 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index d11ba65a13736..8947f7a9bd9a1 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -41,7 +41,7 @@ def Tosa_ArgMaxOp : Tosa_InferShapedTypeOp<"argmax"> {
}];
let arguments = (ins
- Tosa_Tensor: $input,
+ Tosa_Tensor_Extended: $input,
I32Attr: $axis,
DefaultValuedAttr<Tosa_NanPropagationAttr, "\"PROPAGATE\"">:$nan_mode
);
@@ -73,7 +73,8 @@ def Tosa_AvgPool2dOp : Tosa_InferShapedTypeOp<"avg_pool2d"> {
}];
let arguments = (ins
- Tosa_Tensor4D:$input,
+ Tosa_Tensor4D_Extended:$input,
+
Tosa_IntArrayAttr2:$kernel,
Tosa_IntArrayAttr2:$stride,
Tosa_IntArrayAttr4:$pad,
@@ -83,7 +84,7 @@ def Tosa_AvgPool2dOp : Tosa_InferShapedTypeOp<"avg_pool2d"> {
);
let results = (outs
- Tosa_Tensor4D:$output
+ Tosa_Tensor4D_Extended:$output
);
let builders = [Tosa_AvgPool2dOpQuantInfoBuilder];
@@ -102,7 +103,7 @@ def Tosa_Conv2DOp : Tosa_ConvOp<"conv2d"> {
}];
let arguments = (ins
- Tosa_Tensor4D:$input,
+ Tosa_Tensor4D_Extended:$input,
TosaTensorRankOf<[Tosa_Weight], [4]>:$weight,
Tosa_Tensor1D:$bias,
Optional<Tosa_ScalarTensor>:$input_zp,
@@ -133,11 +134,12 @@ def Tosa_Conv3DOp : Tosa_ConvOp<"conv3d"> {
}];
let arguments = (ins
- Tosa_Tensor5D:$input,
- TosaTensorRankOf<[Tosa_Weight], [5]>:$weight,
- Tosa_Tensor1D:$bias,
+ Tosa_Tensor5D_Extended:$input,
+ TensorRankOf<[Tosa_Weight], [5]>:$weight,
+ Tosa_Tensor1D_Extended:$bias,
Optional<Tosa_ScalarTensor>:$input_zp,
Optional<Tosa_ScalarTensor>:$weight_zp,
+
Tosa_IntArrayAttr6:$pad,
Tosa_IntArrayAttr3:$stride,
Tosa_IntArrayAttr3:$dilation,
@@ -146,7 +148,7 @@ def Tosa_Conv3DOp : Tosa_ConvOp<"conv3d"> {
);
let results = (outs
- Tosa_Tensor5D:$output
+ Tosa_Tensor5D_Extended:$output
);
let builders = [Tosa_ConvOpQuantInfoBuilder];
@@ -165,11 +167,12 @@ def Tosa_DepthwiseConv2DOp : Tosa_ConvOp<"depthwise_conv2d"> {
}];
let arguments = (ins
- Tosa_Tensor4D:$input,
+ Tosa_Tensor4D_Extended:$input,
TosaTensorRankOf<[Tosa_Weight], [4]>:$weight,
- Tosa_Tensor1D:$bias,
+ Tosa_Tensor1D_Extended:$bias,
Optional<Tosa_ScalarTensor>:$input_zp,
Optional<Tosa_ScalarTensor>:$weight_zp,
+
Tosa_IntArrayAttr4:$pad,
Tosa_IntArrayAttr2:$stride,
Tosa_IntArrayAttr2:$dilation,
@@ -178,7 +181,7 @@ def Tosa_DepthwiseConv2DOp : Tosa_ConvOp<"depthwise_conv2d"> {
);
let results = (outs
- Tosa_Tensor4D:$output
+ Tosa_Tensor4D_Extended:$output
);
let builders = [Tosa_ConvOpQuantInfoBuilder];
@@ -237,8 +240,8 @@ def Tosa_MatMulOp : Tosa_InferShapedTypeOp<"matmul"> {
}];
let arguments = (ins
- Tosa_Tensor3D:$a,
- Tosa_Tensor3D:$b,
+ Tosa_Tensor3D_Extended:$a,
+ Tosa_Tensor3D_Extended:$b,
OptionalAttr<I32Attr>:$a_zp,
OptionalAttr<I32Attr>:$b_zp
);
@@ -248,6 +251,7 @@ def Tosa_MatMulOp : Tosa_InferShapedTypeOp<"matmul"> {
);
let builders = [Tosa_MatMulOpQuantInfoBuilder];
+ let hasVerifier = 1;
}
//===----------------------------------------------------------------------===//
@@ -264,7 +268,7 @@ def Tosa_MaxPool2dOp : Tosa_InferShapedTypeOp<"max_pool2d"> {
}];
let arguments = (ins
- Tosa_Tensor4D:$input,
+ Tosa_Tensor4D_Extended:$input,
Tosa_IntArrayAttr2:$kernel,
Tosa_IntArrayAttr2:$stride,
@@ -273,10 +277,11 @@ def Tosa_MaxPool2dOp : Tosa_InferShapedTypeOp<"max_pool2d"> {
);
let results = (outs
- Tosa_Tensor4D:$output
+ Tosa_Tensor4D_Extended:$output
);
let hasCanonicalizer = 1;
+ let hasVerifier = 1;
}
//===----------------------------------------------------------------------===//
@@ -327,11 +332,12 @@ def Tosa_TransposeConv2DOp : Tosa_ConvOp<"transpose_conv2d"> {
}];
let arguments = (ins
- Tosa_Tensor4D:$input,
+ Tosa_Tensor4D_Extended:$input,
TosaTensorRankOf<[Tosa_Weight], [4]>:$weight,
- Tosa_Tensor1D:$bias,
+ Tosa_Tensor1D_Extended:$bias,
Optional<Tosa_ScalarTensor>:$input_zp,
Optional<Tosa_ScalarTensor>:$weight_zp,
+
Tosa_IntArrayAttr4:$out_pad,
Tosa_IntArrayAttr2:$stride,
Tosa_IntArrayAttr4:$out_shape,
@@ -1190,9 +1196,9 @@ def Tosa_SelectOp : Tosa_ElementwiseOp<"select"> {
}];
let arguments = (ins
- Tosa_I1Tensor:$pred,
- Tosa_Tensor:$on_true,
- Tosa_Tensor:$on_false
+ Tosa_I1Tensor:$input1,
+ Tosa_Tensor:$input2,
+ Tosa_Tensor:$input3
);
let results = (outs
@@ -1200,9 +1206,10 @@ def Tosa_SelectOp : Tosa_ElementwiseOp<"select"> {
);
let hasCanonicalizeMethod = 1;
let hasFolder = 1;
+ let hasVerifier = 1;
let assemblyFormat = [{
- operands attr-dict `:` `(` type($pred) `,` type($on_true) `,` type($on_false)
+ operands attr-dict `:` `(` type($input1) `,` type($input2) `,` type($input3)
`)` `->` type($output)
}];
}
@@ -1518,16 +1525,17 @@ def Tosa_ConcatOp : Tosa_InferTensorTypeOp<"concat"> {
}];
let arguments = (ins
- Variadic<Tosa_Tensor>:$input1,
+ Variadic<Tosa_Tensor_Extended>:$input1,
I32Attr:$axis
);
let results = (outs
- Tosa_Tensor:$output
+ Tosa_Tensor_Extended:$output
);
let hasCanonicalizer = 1;
let hasFolder = 1;
+ let hasVerifier = 1;
let extraClassDeclaration = [{
/// Returns true when two result types are compatible for this op;
@@ -1563,14 +1571,14 @@ def Tosa_PadOp : Tosa_InferShapedTypeOp<"pad"> {
}];
let arguments = (ins
- Tosa_RankedTensor:$input1,
+ Tosa_RankedTensor_Extended:$input1,
Tosa_Shape:$padding,
- Optional<Tosa_Rank0Tensor>:$pad_const,
+ Optional<Tosa_ScalarTensor_Extended>:$pad_const,
OptionalAttr<I32Attr>:$input_zp
);
let results = (outs
- Tosa_RankedTensor:$output
+ Tosa_RankedTensor_Extended:$output
);
let builders = [Tosa_PadOpQuantInfoBuilder,
@@ -1597,12 +1605,12 @@ def Tosa_ReshapeOp : Tosa_InferTensorTypeOp<"reshape"> {
let hasVerifier = 1;
let arguments = (ins
- Tosa_Tensor:$input1,
+ Tosa_Tensor_Extended:$input1,
Tosa_Shape:$shape
);
let results = (outs
- Tosa_RankedTensor:$output
+ Tosa_RankedTensor_Extended:$output
);
let extraClassDeclaration = [{
@@ -1629,12 +1637,12 @@ def Tosa_ReverseOp: Tosa_Op<"reverse", [
}];
let arguments = (ins
- Tosa_Tensor:$input1,
+ Tosa_Tensor_Extended:$input1,
I32Attr:$axis
);
let results = (outs
- Tosa_Tensor:$output
+ Tosa_Tensor_Extended:$output
);
let hasFolder = 1;
@@ -1656,13 +1664,13 @@ def Tosa_SliceOp : Tosa_InferShapedTypeOp<"slice"> {
}];
let arguments = (ins
- Tosa_Tensor:$input1,
+ Tosa_Tensor_Extended:$input1,
Tosa_Shape:$start,
Tosa_Shape:$size
);
let results = (outs
- Tosa_Tensor:$output
+ Tosa_Tensor_Extended:$output
);
let hasCanonicalizer = 1;
@@ -1681,11 +1689,11 @@ def Tosa_TileOp : Tosa_InferShapedTypeOp<"tile"> {
}];
let arguments = (ins
- Tosa_Tensor:$input1,
+ Tosa_Tensor_Extended:$input1,
Tosa_Shape:$multiples);
let results = (outs
- Tosa_Tensor:$output
+ Tosa_Tensor_Extended:$output
);
let extraClassDeclaration = [{
@@ -1709,12 +1717,12 @@ def Tosa_TransposeOp : Tosa_InferShapedTypeOp<"transpose",
}];
let arguments = (ins
- Tosa_Tensor:$input1,
- Tosa_Int32Tensor:$perms
+ Tosa_Tensor_Extended:$input1,
+ Tosa_Int32Or64Tensor:$perms
);
let results = (
- outs Tosa_Tensor:$output
+ outs Tosa_Tensor_Extended:$output
);
let extraClassDeclaration = [{
@@ -1743,13 +1751,14 @@ def Tosa_GatherOp : Tosa_InferShapedTypeOp<"gather"> {
}];
let arguments = (ins
- Tosa_Tensor3D:$values,
- TosaTensorRankOf<[Tosa_Int32], [2]>:$indices
+ Tosa_Tensor3D_Extended:$values,
+ 2DTensorOf<[Tosa_Int32]>:$indices
);
let results = (outs
- Tosa_Tensor3D:$output
+ Tosa_Tensor3D_Extended:$output
);
+ let hasVerifier = 1;
}
//===----------------------------------------------------------------------===//
@@ -1764,14 +1773,15 @@ def Tosa_ScatterOp : Tosa_InferShapedTypeOp<"scatter"> {
}];
let arguments = (ins
- Tosa_Tensor3D:$values_in,
- TosaTensorRankOf<[Tosa_Int32], [2]>:$indices,
- Tosa_Tensor3D:$input
+ Tosa_Tensor3D_Extended:$values_in,
+ 2DTensorOf<[Tosa_Int32]>:$indices,
+ Tosa_Tensor3D_Extended:$input
);
let results = (outs
- Tosa_Tensor3D:$values_out
+ Tosa_Tensor3D_Extended:$values_out
);
+ let hasVerifier = 1;
}
//===----------------------------------------------------------------------===//
@@ -1828,37 +1838,66 @@ def Tosa_CastOp: Tosa_Op<"cast", [Pure,
| Mode | Input | Output |
|--------------------------|---------|---------|
- | signed 8 to bool | int8 | Boolean |
- | signed 16 to bool | int16 | Boolean |
- | signed 32 to bool | int32 | Boolean |
- | bool to 8 | Boolean | int8 |
- | bool to 16 | Boolean | int16 |
- | bool to 32 | Boolean | int32 |
- | signed 8 to signed 16 | int8 | int16 |
- | signed 8 to signed 32 | int8 | int32 |
- | signed 16 to signed 8 | int16 | int8 |
- | signed 16 to signed 32 | int16 | int32 |
- | signed 32 to signed 8 | int32 | int8 |
- | signed 32 to signed 16 | int32 | int16 |
- | float to signed 8 | float | int8 |
- | float to signed 16 | float | int16 |
- | signed 8 to float | int8 | float |
- | signed 16 to float | int16 | float |
- | float 32 to float 64 | float32 | float64 |
- | float 64 to float 32 | float64 | float32 |
- }];
-
- let arguments = (ins
- Tosa_Tensor:$input
- );
-
- let results = (outs
- Tosa_Tensor:$output
+ | bool to int 8 | Boolean | int8 |
+ | bool to int 16 | Boolean | int16 |
+ | bool to int 32 | Boolean | int32 |
+ | int 8 to bool | int8 | Boolean |
+ | int 8 to int 16 | int8 | int16 |
+ | int 8 to int 32 | int8 | int32 |
+ | int 8 to fp16 | int8 | float16 |
+ | int 8 to bf16 | int8 | bf16 |
+ | int 8 to fp32 | int8 | float32 |
+ | int 16 to bool | int16 | Boolean |
+ | int 16 to int 8 | int16 | int8 |
+ | int 16 to int 32 | int16 | int32 |
+ | int 16 to fp16 | int16 | float16 |
+ | int 16 to bf16 | int16 | bf16 |
+ | int 16 to fp32 | int16 | float32 |
+ | int 32 to bool | int32 | Boolean |
+ | int 32 to int 8 | int32 | int8 |
+ | int 32 to int 16 | int32 | int16 |
+ | int 32 to fp16 | int32 | float16 |
+ | int 32 to bf16 | int32 | bf16 |
+ | int 32 to fp32 | int32 | float32 |
+ | bf16 to int 8 | bf16 | int8 |
+ | bf16 to int 16 | bf16 | int16 |
+ | bf16 to int 32 | bf16 | int32 |
+ | bf16 to fp8e4m3 | bf16 | fp8e4m3 |
+ | bf16 to fp8e5m2 | bf16 | fp8e5m2 |
+ | bf16 to fp32 | bf16 | float32 |
+ | fp8e4m3 to fp16 | fp8e4m3 | float16 |
+ | fp8e4m3 to bf16 | fp8e4m3 | bf16 |
+ | fp8e4m3 to fp32 | fp8e4m3 | float32 |
+ | fp8e5m2 to fp16 | fp8e5m2 | float16 |
+ | fp8e5m2 to bf16 | fp8e5m2 | bf16 |
+ | fp8e5m2 to fp32 | fp8e5m2 | float32 |
+ | fp16 to int 8 | float16 | int8 |
+ | fp16 to int 16 | float16 | int16 |
+ | fp16 to int 32 | float16 | int32 |
+ | fp16 to fp8e4m3 | float16 | fp8e4m3 |
+ | fp16 to fp8e5m2 | float16 | fp8e5m2 |
+ | fp16 to fp32 | float16 | float32 |
+ | fp32 to int 8 | float32 | int8 |
+ | fp32 to int 16 | float32 | int16 |
+ | fp32 to int 32 | float32 | int32 |
+ | fp32 to fp8e4m3 | float32 | fp8e4m3 |
+ | fp32 to fp8e5m2 | float32 | fp8e5m2 |
+ | fp32 to bf16 | float32 | bf16 |
+ | fp32 to fp16 | float32 | float16 |
+ }];
+
+ let arguments = (ins
+ TensorOf<[AnyTypeOf<[Tosa_AnyNumber_Extended, F64]>]>:$input
+ );
+
+ let results = (outs
+ TensorOf<[AnyTypeOf<[Tosa_AnyNumber_Extended, F64]>]>:$output
);
let assemblyFormat = "operands attr-dict `:` functional-type(operands, results)";
let hasFolder = 1;
+ let hasVerifier = 1;
}
//===----------------------------------------------------------------------===//
@@ -1940,7 +1979,7 @@ def Tosa_ConstOp : Tosa_Op<"const", [ConstantLike, Pure,
);
let results = (outs
- TosaTensorOf<[AnyTypeOf<[Tosa_AnyNumber]>]>:$output
+ TensorOf<[AnyTypeOf<[Tosa_AnyNumber_Extended, F64, Tosa_Int4]>]>:$output
);
let hasFolder = 1;
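As a quick illustration of the extended CAST table above (tensor shapes are
arbitrary and the input values %in8 / %in32 are assumed to be defined), the
new modes admit conversions such as:
  %0 = tosa.cast %in8 : (tensor<13x21x3xf8E5M2>) -> tensor<13x21x3xf16>
  %1 = tosa.cast %in32 : (tensor<13x21x3xf32>) -> tensor<13x21x3xf8E4M3FN>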
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
index cf6ddc66f4ada..2c6e647ae32fd 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
@@ -74,16 +74,25 @@ def Tosa_QuantizedInt : AnyTypeOf<[Tosa_QuantizedType<"uint8", [8], 0>,
Tosa_QuantizedType<"int16", [16, 0], 1>,
Tosa_QuantizedType<"int32", [32, 0], 1>]>;
+def Tosa_F8 : AnyTypeOf<[
+ F8E4M3FN,
+ F8E5M2]>;
+
//===----------------------------------------------------------------------===//
// Multi-category types.
//===----------------------------------------------------------------------===//
def Tosa_AnyNumber : AnyTypeOf<[Tosa_Int, Tosa_QuantizedInt, AnyFloat],
"number">;
+// Add F8 type support to Tosa_AnyNumber
+def Tosa_AnyNumber_Extended : AnyTypeOf<[Tosa_AnyNumber, Tosa_F8],
+ "number_extended">;
+
// For weight tensors from tosa::Conv2DOp, tosa::Conv3DOp,
// tosa::DepthwiseConv2DOp, tosa::TransposeConv2DOp
def Tosa_Weight : AnyTypeOf<[Tosa_Int4, Tosa_Int8,
- Tosa_QuantizedInt, AnyFloat]>;
+ Tosa_QuantizedInt, AnyFloat, Tosa_F8]>;
+
//===----------------------------------------------------------------------===//
// TOSA Tensor Conformance
@@ -130,9 +139,11 @@ def Tosa_FloatTensor : TosaTensorOf<[AnyFloat]>;
// Either ranked or unranked tensor of TOSA supported element types.
def Tosa_Tensor : TosaTensorOf<[Tosa_AnyNumber]>;
+def Tosa_Tensor_Extended : TosaTensorOf<[Tosa_AnyNumber_Extended]>;
// Must be ranked but no further constraints
-def Tosa_RankedTensor : TosaRankedTensorOf<[Tosa_AnyNumber]>;
+def Tosa_RankedTensor : RankedTensorOf<[Tosa_AnyNumber]>;
+def Tosa_RankedTensor_Extended : RankedTensorOf<[Tosa_AnyNumber_Extended]>;
// Any tensor element type allowed in Tosa ops.
def Tosa_ElementType : Type<Or<[Tosa_Int.predicate, Tosa_QuantizedInt.predicate,
@@ -145,9 +156,9 @@ class Tosa_TensorOfOrNone<list<Type> allowedTypes, string description = ""> :
// Tensor types with constrained ranks.
//===----------------------------------------------------------------------===//
-def Tosa_Rank0Tensor : TosaTensorRankOf<[Tosa_AnyNumber], [0]>;
-
+// Scalar tensors: rank-1 tensors with exactly one element
def Tosa_ScalarTensor : TosaScalarTensorOf<[Tosa_AnyNumber], [1]>;
+def Tosa_ScalarTensor_Extended : TosaScalarTensorOf<[Tosa_AnyNumber_Extended], [1]>;
def Tosa_ScalarInt8Tensor : TosaScalarTensorOf<[Tosa_Int8], [1]>;
// We include unranked tensors as a supported type for all possible tosa
@@ -155,6 +166,7 @@ def Tosa_ScalarInt8Tensor : TosaScalarTensorOf<[Tosa_Int8], [1]>;
// they should be shape propagate used Tosa's shape inference pass and verified
// to not include any remaining unranked tensors.
def Tosa_UnrankedTensor : TosaUnrankedTensorOf<[Tosa_AnyNumber]>;
+def Tosa_UnrankedTensorExtended : TosaUnrankedTensorOf<[Tosa_AnyNumber_Extended]>;
def Tosa_Tensor1D : AnyTypeOf<[Tosa_UnrankedTensor, TosaTensorRankOf<[Tosa_AnyNumber], [1]>], "1-d tosa-conformant tensor", "::mlir::TensorType">;
def Tosa_Tensor2D : AnyTypeOf<[Tosa_UnrankedTensor, TosaTensorRankOf<[Tosa_AnyNumber], [2]>], "2-d tosa-conformant tensor", "::mlir::TensorType">;
@@ -162,6 +174,17 @@ def Tosa_Tensor3D : AnyTypeOf<[Tosa_UnrankedTensor, TosaTensorRankOf<[Tosa_AnyNu
def Tosa_Tensor4D : AnyTypeOf<[Tosa_UnrankedTensor, TosaTensorRankOf<[Tosa_AnyNumber], [4]>], "4-d tosa-conformant tensor", "::mlir::TensorType">;
def Tosa_Tensor5D : AnyTypeOf<[Tosa_UnrankedTensor, TosaTensorRankOf<[Tosa_AnyNumber], [5]>], "5-d tosa-conformant tensor", "::mlir::TensorType">;
+def Tosa_Tensor1D_Extended : AnyTypeOf<[Tosa_UnrankedTensorExtended, TosaTensorRankOf<[Tosa_AnyNumber_Extended], [1]>],
+ "1-d tosa-conformant tensor extended", "::mlir::TensorType">;
+def Tosa_Tensor2D_Extended : AnyTypeOf<[Tosa_UnrankedTensorExtended, TosaTensorRankOf<[Tosa_AnyNumber_Extended], [2]>],
+ "2-d tosa-conformant tensor extended", "::mlir::TensorType">;
+def Tosa_Tensor3D_Extended : AnyTypeOf<[Tosa_UnrankedTensorExtended, TosaTensorRankOf<[Tosa_AnyNumber_Extended], [3]>],
+ "3-d tosa-conformant tensor extended", "::mlir::TensorType">;
+def Tosa_Tensor4D_Extended : AnyTypeOf<[Tosa_UnrankedTensorExtended, TosaTensorRankOf<[Tosa_AnyNumber_Extended], [4]>],
+ "4-d tosa-conformant tensor extended", "::mlir::TensorType">;
+def Tosa_Tensor5D_Extended : AnyTypeOf<[Tosa_UnrankedTensorExtended, TosaTensorRankOf<[Tosa_AnyNumber_Extended], [5]>],
+ "5-d tosa-conformant tensor extended", "::mlir::TensorType">;
+
// Ranked tensors up to given rank.
def Tosa_Tensor1Dto4D : AnyTypeOf<[
Tosa_UnrankedTensor, TosaTensorRankOf<[Tosa_AnyNumber], [1,2,3,4]>]>;
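As a quick illustration, the *_Extended tensor constraints let existing ops
accept fp8 element types directly (this example mirrors the new argmax test
added in ops.mlir below):
  %0 = tosa.argmax %arg0 { axis = 1 : i32 } : (tensor<12x8x16xf8E5M2>) -> tensor<12x16xi32>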
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index 69b3f6d674167..704f8a82d11fa 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -65,12 +65,12 @@ void ConcatOp::getCanonicalizationPatterns(RewritePatternSet &results,
}
LogicalResult SelectOp::canonicalize(SelectOp op, PatternRewriter &rewriter) {
- auto notOp = op.getPred().getDefiningOp<tosa::LogicalNotOp>();
+ auto notOp = op.getInput1().getDefiningOp<tosa::LogicalNotOp>();
if (!notOp)
return failure();
rewriter.modifyOpInPlace(op, [&]() {
op.getOperation()->setOperands(
- {notOp.getInput1(), op.getOnFalse(), op.getOnTrue()});
+ {notOp.getInput1(), op.getInput3(), op.getInput2()});
});
return success();
}
@@ -1118,18 +1118,18 @@ OpFoldResult SliceOp::fold(FoldAdaptor adaptor) {
}
OpFoldResult tosa::SelectOp::fold(FoldAdaptor adaptor) {
- if (getOnTrue() == getOnFalse())
- return getOnTrue();
+ if (getInput2() == getInput3())
+ return getInput2();
auto predicate =
- llvm::dyn_cast_if_present<DenseIntElementsAttr>(adaptor.getPred());
+ llvm::dyn_cast_if_present<DenseIntElementsAttr>(adaptor.getInput1());
if (!predicate)
return {};
if (!predicate.isSplat())
return {};
- return predicate.getSplatValue<APInt>().getBoolValue() ? getOnTrue()
- : getOnFalse();
+ return predicate.getSplatValue<APInt>().getBoolValue() ? getInput2()
+ : getInput3();
}
OpFoldResult TileOp::fold(FoldAdaptor adaptor) {
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 67021d6c07401..d88050f6118d7 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -217,15 +217,17 @@ void mlir::tosa::printTypeOrAttr(OpAsmPrinter &p, Operation *op, TypeAttr type,
template <typename T>
static LogicalResult verifyConvOp(T op) {
- // All TOSA conv ops have an input and weight arguments which must be ranked
- // tensors.
+  // All TOSA conv ops have input and weight arguments.
auto inputType = llvm::dyn_cast<RankedTensorType>(op.getInput().getType());
+
+  auto weightType =
+      llvm::dyn_cast<RankedTensorType>(op.getWeight().getType());
+
+ // Must be ranked tensor types
if (!inputType) {
op.emitOpError("expect a ranked tensor for input, got ") << op.getInput();
return failure();
}
-
- auto weightType = llvm::dyn_cast<RankedTensorType>(op.getWeight().getType());
if (!weightType) {
op.emitOpError("expect a ranked tensor for weight, got ") << op.getWeight();
return failure();
@@ -243,6 +245,9 @@ static LogicalResult verifyConvOp(T op) {
if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(inputEType))
inputEType = quantType.getStorageType();
+ if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(weightEType))
+ weightEType = quantType.getStorageType();
+
if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(biasEType))
biasEType = quantType.getStorageType();
@@ -258,14 +263,22 @@ static LogicalResult verifyConvOp(T op) {
return failure();
}
- if (isa<Float8E5M2Type>(inputEType) || isa<Float8E4M3FNType>(inputEType) ||
- isa<Float8E5M2Type>(weightEType) || isa<Float8E4M3FNType>(weightEType)) {
+ if (llvm::isa<Float8E5M2Type>(inputEType) ||
+ llvm::isa<Float8E4M3FNType>(inputEType) ||
+ llvm::isa<Float8E5M2Type>(weightEType) ||
+ llvm::isa<Float8E4M3FNType>(weightEType)) {
if (inputEType != weightEType) {
op.emitOpError(
"expect both input and weight to have same element type, got ")
<< inputEType << " and " << weightEType;
return failure();
}
+
+ if (!resultEType.isF16()) {
+ op.emitOpError("expect bias and result element type to be f16, got ")
+ << resultEType;
+ return failure();
+ }
}
bool inputIsFloat = llvm::isa<FloatType>(inputEType);
@@ -459,9 +472,18 @@ LogicalResult tosa::AvgPool2dOp::verify() {
if (inputETy.isF32() && !accType.isF32())
return emitOpError("accumulator type for f32 tensor is not f32");
+ if ((llvm::isa<Float8E5M2Type>(inputETy) ||
+ llvm::isa<Float8E4M3FNType>(inputETy)) &&
+ !accType.isF16())
+ return emitOpError("accumulator type for f8 tensor is not f16");
+
if ((inputETy.isF32() && resultETy.isF32()) ||
(inputETy.isF16() && resultETy.isF16()) ||
(inputETy.isBF16() && resultETy.isBF16()) ||
+ (llvm::isa<Float8E5M2Type>(inputETy) &&
+ llvm::isa<Float8E5M2Type>(resultETy)) ||
+ (llvm::isa<Float8E4M3FNType>(inputETy) &&
+ llvm::isa<Float8E4M3FNType>(resultETy)) ||
(inputETy.isInteger(8) && resultETy.isInteger(8)) ||
(inputETy.isInteger(16) && resultETy.isInteger(16)))
return success();
@@ -469,6 +491,104 @@ LogicalResult tosa::AvgPool2dOp::verify() {
return emitOpError("input/output element types are incompatible.");
}
+LogicalResult tosa::CastOp::verify() {
+ mlir::Type inputETy =
+ llvm::cast<ShapedType>(getInput().getType()).getElementType();
+ if (auto inputQuantType =
+ llvm::dyn_cast<mlir::quant::QuantizedType>(inputETy)) {
+ inputETy = inputQuantType.getStorageType();
+ }
+ mlir::Type outputETy =
+ llvm::cast<ShapedType>(getOutput().getType()).getElementType();
+ if (auto outputQuantType =
+ llvm::dyn_cast<mlir::quant::QuantizedType>(outputETy)) {
+ outputETy = outputQuantType.getStorageType();
+ }
+
+ // input element type: bool
+ if (inputETy.isInteger(1)) {
+ if (outputETy.isInteger(8) || outputETy.isInteger(16) ||
+ outputETy.isInteger(32)) {
+ return success();
+ }
+ }
+ // input element type: int8
+ if (inputETy.isInteger(8)) {
+ if (outputETy.isInteger(1) || outputETy.isInteger(16) ||
+ outputETy.isInteger(32) || outputETy.isF16() || outputETy.isBF16() ||
+ outputETy.isF32()) {
+ return success();
+ }
+ }
+ // input element type: int16
+ if (inputETy.isInteger(16)) {
+ if (outputETy.isInteger(1) || outputETy.isInteger(8) ||
+ outputETy.isInteger(32) || outputETy.isF16() || outputETy.isBF16() ||
+ outputETy.isF32()) {
+ return success();
+ }
+ }
+ // input element type: int32
+ if (inputETy.isInteger(32)) {
+ if (outputETy.isInteger(1) || outputETy.isInteger(8) ||
+ outputETy.isInteger(16) || outputETy.isF16() || outputETy.isBF16() ||
+ outputETy.isF32()) {
+ return success();
+ }
+ }
+ // input element type: bf16 or fp16
+ if (inputETy.isBF16() || inputETy.isF16()) {
+ if (outputETy.isInteger(8) || outputETy.isInteger(16) ||
+ outputETy.isInteger(32) || llvm::isa<Float8E5M2Type>(outputETy) ||
+ llvm::isa<Float8E4M3FNType>(outputETy) || outputETy.isF32()) {
+ return success();
+ }
+ }
+ // input element type: f8e4m3 or f8e5m2
+ if (llvm::isa<Float8E4M3FNType>(inputETy) ||
+ llvm::isa<Float8E5M2Type>(inputETy)) {
+ if (outputETy.isF16() || outputETy.isBF16() || outputETy.isF32()) {
+ return success();
+ }
+ }
+ // input element type: fp32
+ if (inputETy.isF32()) {
+ if (outputETy.isInteger(8) || outputETy.isInteger(16) ||
+ outputETy.isInteger(32) || llvm::isa<Float8E5M2Type>(outputETy) ||
+ llvm::isa<Float8E4M3FNType>(outputETy) || outputETy.isF16() ||
+ outputETy.isBF16()) {
+ return success();
+ }
+ }
+
+  // The following cases fall outside the TOSA specification.
+
+  // allow casting to same type, for quantization/dequantization
+ if (inputETy == outputETy) {
+ return success();
+ }
+
+ // allow casting float to bool, for tosa_to_linalg testing
+ if (inputETy.isF32() && outputETy.isInteger(1)) {
+ return success();
+ }
+
+ // special case for I64
+ if (inputETy.isInteger(64) || outputETy.isInteger(64)) {
+    // be forgiving of casting to and from I64
+ return success();
+ }
+
+ // special case for fp64
+ if (inputETy.isF64() || outputETy.isF64()) {
+ // be forgiving of casting to and from F64
+ return success();
+ }
+
+ return emitOpError("input/output element types are incompatible: ")
+ << inputETy << " and " << outputETy;
+}
+
LogicalResult tosa::ClampOp::verify() {
mlir::Type inputETy =
llvm::cast<ShapedType>(getInput().getType()).getElementType();
@@ -849,6 +969,18 @@ LogicalResult tosa::ConcatOp::inferReturnTypeComponents(
return success();
}
+LogicalResult tosa::ConcatOp::verify() {
+ // check that each input has same element type as output
+ auto outType = getOutput().getType();
+ for (auto input : getInput1()) {
+ if (verifySameElementTypes(*this, /* inType = */ input.getType(), outType)
+ .failed()) {
+ return failure();
+ }
+ }
+ return success();
+}
+
LogicalResult tosa::EqualOp::inferReturnTypeComponents(
MLIRContext *context, ::std::optional<Location> location,
ValueShapeRange operands, DictionaryAttr attributes,
@@ -898,6 +1030,108 @@ LogicalResult tosa::MatMulOp::inferReturnTypeComponents(
return success();
}
+LogicalResult MatMulOp::verify() {
+ auto aType = llvm::dyn_cast<ShapedType>(getA().getType());
+ auto bType = llvm::dyn_cast<ShapedType>(getB().getType());
+ auto resultEType =
+ llvm::cast<ShapedType>(getResult().getType()).getElementType();
+
+ // Must be shaped tensor types
+ if (!aType) {
+ emitOpError("expect a shaped tensor for input a, got ") << getA().getType();
+ return failure();
+ }
+ if (!bType) {
+ emitOpError("expect a shaped tensor for input b, got ") << getB().getType();
+ return failure();
+ }
+
+ auto aElementType = aType.getElementType();
+ auto bElementType = bType.getElementType();
+
+ auto aQuantizedEType =
+ llvm::dyn_cast<quant::UniformQuantizedType>(aElementType);
+ auto bQuantizedEType =
+ llvm::dyn_cast<quant::UniformQuantizedType>(bElementType);
+
+ if (aQuantizedEType || bQuantizedEType) {
+ if (!aQuantizedEType || !bQuantizedEType) {
+ emitOpError(
+ "expect operands to be both quantized or both not quantized, got ")
+ << aElementType << " and " << bElementType;
+ return failure();
+ }
+ // both a and b have quantized element types
+ auto aQuantWidth = aQuantizedEType.getStorageTypeIntegralWidth();
+ auto bQuantWidth = bQuantizedEType.getStorageTypeIntegralWidth();
+ if (aQuantWidth != bQuantWidth) {
+ emitOpError("expect quantized operands to have same widths, got ")
+ << aQuantWidth << " and " << bQuantWidth;
+ return failure();
+ }
+
+ if (aQuantWidth != 8 && aQuantWidth != 16) {
+ emitOpError("only support quantized types with width of 8 or 16, got ")
+ << aQuantWidth;
+ return failure();
+ }
+
+ // check result types
+ if (aQuantWidth == 8 && !resultEType.isInteger(32)) {
+ emitOpError("expect result element type to be i32, got ") << resultEType;
+ return failure();
+ }
+
+ // check result types
+ if (aQuantWidth == 16 && !resultEType.isInteger(48)) {
+ emitOpError("expect result element type to be i48, got ") << resultEType;
+ return failure();
+ }
+
+ return success();
+ }
+
+ // non-quantized element types
+
+ if (aElementType != bElementType) {
+ emitOpError("expect same element type for inputs a and b, got ")
+ << aElementType << " and " << bElementType;
+ return failure();
+ }
+
+ if (llvm::isa<Float8E5M2Type>(aElementType) ||
+ llvm::isa<Float8E4M3FNType>(aElementType)) {
+ if (!resultEType.isF16()) {
+ emitOpError("expect result element type to be f16, got ") << resultEType;
+ return failure();
+ }
+ }
+
+ if (aElementType.isInteger(8) && !resultEType.isInteger(32)) {
+ emitOpError("expect result element type to be i32, got ") << resultEType;
+ return failure();
+ }
+ if (aElementType.isInteger(16) && !resultEType.isInteger(48)) {
+ emitOpError("expect result element type to be i48, got ") << resultEType;
+ return failure();
+ }
+ if (aElementType.isF16() && !(resultEType.isF16() || resultEType.isF32())) {
+ emitOpError("expect result element type to be f16 or f32, got ")
+ << resultEType;
+ return failure();
+ }
+ if (aElementType.isBF16() && !resultEType.isF32()) {
+ emitOpError("expect result element type to be f32, got ") << resultEType;
+ return failure();
+ }
+ if (aElementType.isF32() && !resultEType.isF32()) {
+ emitOpError("expect result element type to be f32, got ") << resultEType;
+ return failure();
+ }
+
+ return success();
+}
+
LogicalResult tosa::PadOp::inferReturnTypeComponents(
MLIRContext *context, ::std::optional<Location> location,
PadOp::Adaptor adaptor,
@@ -958,6 +1192,18 @@ LogicalResult tosa::PadOp::verify() {
<< inputType.getRank() * 2
<< " (2*rank(shape1)) but got size " << paddingRank;
+ if (verifySameElementTypes(*this, /* inType = */ getInput1().getType(),
+ /* outType = */ getOutput().getType())
+ .failed()) {
+ return failure();
+ }
+ if (auto padConst = getPadConst()) {
+ if (verifySameElementTypes(*this, /* inType = */ padConst.getType(),
+ /* outType = */ getOutput().getType())
+ .failed()) {
+ return failure();
+ }
+ }
return success();
}
@@ -1023,17 +1269,20 @@ LogicalResult tosa::SliceOp::verify() {
if (!inputType)
return success();
+ if (verifySameElementTypes(*this, /* inType = */ getInput1().getType(),
+ /* outType = */ getOutput().getType())
+ .failed())
+ return failure();
+
auto startShapeRank =
llvm::cast<tosa::shapeType>(getStart().getType()).getRank();
if (inputType.getRank() != startShapeRank)
- return emitOpError(
- "length of start attribute is not equal rank of input shape");
+ return emitOpError("length of start is not equal to rank of input shape");
auto sizeShapeRank =
llvm::cast<tosa::shapeType>(getSize().getType()).getRank();
if (inputType.getRank() != sizeShapeRank)
- return emitOpError(
- "length of size attribute is not equal rank of input shape");
+ return emitOpError("length of size is not equal to rank of input shape");
return success();
}
@@ -1238,6 +1487,11 @@ LogicalResult tosa::TileOp::inferReturnTypeComponents(
}
LogicalResult tosa::TileOp::verify() {
+  if (verifySameElementTypes(*this, /* inType = */ getInput1().getType(),
+ /* outType = */ getOutput().getType())
+ .failed()) {
+ return failure();
+ }
ShapedType inputType = llvm::cast<ShapedType>(getInput1().getType());
ShapedType outputType = llvm::cast<ShapedType>(getType());
@@ -1322,6 +1576,12 @@ llvm::LogicalResult tosa::ReshapeOp::verify() {
TensorType inputType = getInput1().getType();
RankedTensorType outputType = getType();
+ if (verifySameElementTypes(*this, /* inType = */ inputType,
+ /* outType = */ outputType)
+ .failed()) {
+ return failure();
+ }
+
SmallVector<int64_t> shapeValues;
if (!tosa::getConstShapeValue(getShape().getDefiningOp(), shapeValues)) {
// skip following checks if shape is not constant
@@ -1463,6 +1723,11 @@ LogicalResult tosa::TransposeOp::inferReturnTypeComponents(
}
LogicalResult tosa::TransposeOp::verify() {
+ if (verifySameElementTypes(*this, /* inType = */ getInput1().getType(),
+ /* outType = */ getOutput().getType())
+ .failed()) {
+ return failure();
+ }
TensorType inputType = getInput1().getType();
TensorType permType = getPerms().getType();
TensorType outputType = getOutput().getType();
@@ -1578,6 +1843,11 @@ LogicalResult tosa::GatherOp::inferReturnTypeComponents(
return success();
}
+LogicalResult tosa::GatherOp::verify() {
+ return verifySameElementTypes(*this, /* inType = */ getValues().getType(),
+ /* outType = */ getOutput().getType());
+}
+
LogicalResult tosa::ResizeOp::inferReturnTypeComponents(
MLIRContext *context, ::std::optional<Location> location,
ResizeOp::Adaptor adaptor,
@@ -1649,6 +1919,18 @@ LogicalResult tosa::ScatterOp::inferReturnTypeComponents(
return success();
}
+LogicalResult tosa::ScatterOp::verify() {
+ if (verifySameElementTypes(*this, /* inType = */ getValuesIn().getType(),
+ /* outType = */ getValuesOut().getType())
+ .failed() ||
+ verifySameElementTypes(*this, /* inType = */ getInput().getType(),
+ /* outType = */ getValuesOut().getType())
+ .failed()) {
+ return failure();
+ }
+ return success();
+}
+
static LogicalResult ReduceInferReturnTypes(
ShapeAdaptor operandShape, Type inputType, IntegerAttr axis,
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
@@ -2013,6 +2295,11 @@ LogicalResult MaxPool2dOp::inferReturnTypeComponents(
inferredReturnShapes);
}
+LogicalResult MaxPool2dOp::verify() {
+  return verifySameElementTypes(*this, /* inType = */ getInput().getType(),
+ /* outType = */ getOutput().getType());
+}
+
LogicalResult DepthwiseConv2DOp::inferReturnTypeComponents(
MLIRContext *context, ::std::optional<Location> location,
DepthwiseConv2DOp::Adaptor adaptor,
@@ -2317,6 +2604,11 @@ void IfOp::print(OpAsmPrinter &p) {
LogicalResult ReverseOp::verify() {
TensorType inputType = getInput1().getType();
TensorType outputType = getOutput().getType();
+ if (verifySameElementTypes(*this, /* inType = */ inputType,
+ /* outType = */ outputType)
+ .failed())
+ return failure();
+
int32_t reverseAxis = getAxis();
if (reverseAxis < 0)
@@ -2343,6 +2635,33 @@ LogicalResult ReverseOp::verify() {
return success();
}
+LogicalResult tosa::SelectOp::verify() {
+ // verify input2 and input3 have same element type as output
+ if (verifySameElementTypes(*this, /* inType = */ getInput2().getType(),
+ /* outType = */ getOutput().getType())
+ .failed() ||
+ verifySameElementTypes(*this, /* inType = */ getInput3().getType(),
+ /* outType = */ getOutput().getType())
+ .failed()) {
+ return failure();
+ }
+ // verify input1 has element type of bool
+ auto predicateType = llvm::dyn_cast<ShapedType>(getInput1().getType());
+ if (!predicateType) {
+ emitOpError("expect shaped tensor for input1, got ")
+ << getInput1().getType();
+ return failure();
+ }
+ auto predicateElementType = predicateType.getElementType();
+ if (!predicateElementType.isInteger(1)) {
+ emitOpError("expect element type of bool for input1, got ")
+ << predicateElementType;
+ return failure();
+ }
+
+ return success();
+}
+
// parse and print of WhileOp refer to the implementation of SCF dialect.
ParseResult WhileOp::parse(OpAsmParser &parser, OperationState &result) {
SmallVector<OpAsmParser::Argument, 4> regionArgs;
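As a quick illustration of the new verifier behavior (hypothetical IR; the
diagnostic text comes from MatMulOp::verify above), fp8 matmul operands must
produce an f16 result, so the following is rejected:
  // error: 'tosa.matmul' op expect result element type to be f16, got 'f32'
  %0 = tosa.matmul %a, %b : (tensor<1x14x19xf8E5M2>, tensor<1x19x28xf8E5M2>) -> tensor<1x14x28xf32>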
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp
index b26397d0e3ed7..a716c2b7595ca 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp
@@ -117,7 +117,7 @@ struct DepthwiseConv2DIsMul : public OpRewritePattern<tosa::DepthwiseConv2DOp> {
Value padSizeVal = getTosaConstShape(rewriter, op->getLoc(), pad);
- auto padTy = RankedTensorType::get({}, inputETy);
+ auto padTy = RankedTensorType::get({1}, inputETy);
auto padAttr = DenseElementsAttr::get(padTy, zeroAttr);
Value padVal =
rewriter.create<tosa::ConstOp>(op->getLoc(), padTy, padAttr);
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaMakeBroadcastable.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaMakeBroadcastable.cpp
index 79afc75fd6c8e..87b2a2695351b 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaMakeBroadcastable.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaMakeBroadcastable.cpp
@@ -169,9 +169,9 @@ struct ConvertTosaOp<tosa::SelectOp> : public OpRewritePattern<tosa::SelectOp> {
LogicalResult matchAndRewrite(tosa::SelectOp tosaOp,
PatternRewriter &rewriter) const override {
- Value input1 = tosaOp.getPred();
- Value input2 = tosaOp.getOnTrue();
- Value input3 = tosaOp.getOnFalse();
+ Value input1 = tosaOp.getInput1();
+ Value input2 = tosaOp.getInput2();
+ Value input3 = tosaOp.getInput3();
Value output = tosaOp.getResult();
auto outputType = dyn_cast<RankedTensorType>(output.getType());
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-pipeline.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-pipeline.mlir
index 73da2810abe04..98c7a2c0be845 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-pipeline.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-pipeline.mlir
@@ -24,7 +24,7 @@ func.func @tensor_with_unknown_rank(%arg0: tensor<*xi8>) -> tensor<*xi8> {
// check that tosa verify kick in
func.func @test_avg_pool2d_zero_dim_input(%arg0: tensor<1x0x?x9xf32>) -> tensor<1x7x7x9xf32> {
- // expected-error at +1 {{'tosa.avg_pool2d' op operand #0 must be 4-d tosa-conformant tensor, but got 'tensor<1x0x?x9xf32>'}}
+ // expected-error at +1 {{'tosa.avg_pool2d' op operand #0 must be 4-d tosa-conformant tensor extended, but got 'tensor<1x0x?x9xf32>'}}
%0 = "tosa.avg_pool2d"(%arg0) {acc_type = f32, kernel = array<i64: 2, 2>, pad = array<i64: 0, 1, 0, 1>, stride = array<i64: 1, 1>}
: (tensor<1x0x?x9xf32>) -> tensor<1x7x7x9xf32>
return %0 : tensor<1x7x7x9xf32>
diff --git a/mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir b/mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir
index c2eaba4c563d0..6b7f622d3303f 100644
--- a/mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir
+++ b/mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir
@@ -542,8 +542,8 @@ func.func @pad_float_explicit(%arg0 : tensor<1x2xf32>) -> (tensor<4x9xf32>) {
// CHECK: tensor.pad %[[ARG0]] low{{\[}}[[INDEX1]], [[INDEX3]]] high{{\[}}[[INDEX2]], [[INDEX4]]] {
// CHECK: tensor.yield [[CST]]
// CHECK: } : tensor<1x2xf32> to tensor<4x9xf32>
- %1 = arith.constant dense<42.0> : tensor<f32>
- %2 = "tosa.pad"(%arg0, %0, %1) : (tensor<1x2xf32>, !tosa.shape<4>, tensor<f32>) -> (tensor<4x9xf32>)
+ %1 = arith.constant dense<42.0> : tensor<1xf32>
+ %2 = "tosa.pad"(%arg0, %0, %1) : (tensor<1x2xf32>, !tosa.shape<4>, tensor<1xf32>) -> (tensor<4x9xf32>)
return %2 : tensor<4x9xf32>
}
diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir
index 24d572244a9b0..95e1d8b0009d4 100644
--- a/mlir/test/Dialect/Tosa/canonicalize.mlir
+++ b/mlir/test/Dialect/Tosa/canonicalize.mlir
@@ -265,7 +265,8 @@ func.func @pad_noop_padding_mismatch_nofold(%arg0: tensor<?x?xf32>) -> tensor<?x
// CHECK: %[[PAD:.+]] = tosa.pad
// CHECK: return %[[PAD]]
%shape = tosa.const_shape { value = dense<[1, 0, 0, 1]> : tensor<4xindex>} : () -> !tosa.shape<4>
- %1 = tosa.pad %arg0, %shape : (tensor<?x?xf32>, !tosa.shape<4>) -> tensor<?x?xf32>
+ %pad_const = "tosa.const"() {value = dense<3.14> : tensor<1xf32>} : () -> tensor<1xf32>
+ %1 = tosa.pad %arg0, %shape, %pad_const : (tensor<?x?xf32>, !tosa.shape<4>, tensor<1xf32>) -> tensor<?x?xf32>
return %1 : tensor<?x?xf32>
}
@@ -276,7 +277,8 @@ func.func @pad_noop_type_mismatch_nofold(%arg0: tensor<10xf32>) -> tensor<?xf32>
// CHECK: %[[PAD:.+]] = tosa.pad
// CHECK: return %[[PAD]]
%shape = tosa.const_shape { value = dense<[1, 2]> : tensor<2xindex>} : () -> !tosa.shape<2>
- %0 = tosa.pad %arg0, %shape : (tensor<10xf32>, !tosa.shape<2>) -> tensor<?xf32>
+ %pad_const = "tosa.const"() {value = dense<3.14> : tensor<1xf32>} : () -> tensor<1xf32>
+ %0 = tosa.pad %arg0, %shape, %pad_const : (tensor<10xf32>, !tosa.shape<2>, tensor<1xf32>) -> tensor<?xf32>
return %0 : tensor<?xf32>
}
@@ -284,11 +286,12 @@ func.func @pad_noop_type_mismatch_nofold(%arg0: tensor<10xf32>) -> tensor<?xf32>
// CHECK-LABEL: @pad_determine_val_i32
func.func @pad_determine_val_i32(%arg0: tensor<?x?xi32>, %arg1 : tensor<2x2xi32>) -> tensor<?x?xi32> {
- // CHECK-DAG: %[[ZERO:.+]] = "tosa.const"() <{value = dense<0> : tensor<i32>}
+ // CHECK-DAG: %[[ZERO:.+]] = "tosa.const"() <{value = dense<3> : tensor<1xi32>}
// CHECK-DAG: %[[PADDING:.+]] = tosa.const_shape {value = dense<[1, 0, 0, 1]> : tensor<4xindex>} : () -> !tosa.shape<4>
// CHECK: tosa.pad %arg0, %[[PADDING]], %[[ZERO]]
+ %pad_const = "tosa.const"() {value = dense<3> : tensor<1xi32>} : () -> tensor<1xi32>
%0 = tosa.const_shape { value = dense<[1, 0, 0, 1]> : tensor<4xindex>} : () -> !tosa.shape<4>
- %1 = tosa.pad %arg0, %0 : (tensor<?x?xi32>, !tosa.shape<4>) -> tensor<?x?xi32>
+ %1 = tosa.pad %arg0, %0, %pad_const : (tensor<?x?xi32>, !tosa.shape<4>, tensor<1xi32>) -> tensor<?x?xi32>
return %1 : tensor<?x?xi32>
}
@@ -296,11 +299,12 @@ func.func @pad_determine_val_i32(%arg0: tensor<?x?xi32>, %arg1 : tensor<2x2xi32>
// CHECK-LABEL: @pad_determine_val_f32
func.func @pad_determine_val_f32(%arg0: tensor<?x?xf32>, %arg1 : tensor<2x2xi32>) -> tensor<?x?xf32> {
- // CHECK-DAG: %[[ZERO:.+]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<f32>}
+ // CHECK-DAG: %[[ZERO:.+]] = "tosa.const"() <{value = dense<3.140000e+00> : tensor<1xf32>}
// CHECK-DAG: %[[PADDING:.+]] = tosa.const_shape {value = dense<[1, 0, 0, 1]> : tensor<4xindex>} : () -> !tosa.shape<4>
// CHECK: tosa.pad %arg0, %[[PADDING]], %[[ZERO]]
+ %pad_const = "tosa.const"() {value = dense<3.14> : tensor<1xf32>} : () -> tensor<1xf32>
%0 = tosa.const_shape { value = dense<[1, 0, 0, 1]> : tensor<4xindex>} : () -> !tosa.shape<4>
- %1 = tosa.pad %arg0, %0 : (tensor<?x?xf32>, !tosa.shape<4>) -> tensor<?x?xf32>
+ %1 = tosa.pad %arg0, %0, %pad_const : (tensor<?x?xf32>, !tosa.shape<4>, tensor<1xf32>) -> tensor<?x?xf32>
return %1 : tensor<?x?xf32>
}
@@ -308,11 +312,12 @@ func.func @pad_determine_val_f32(%arg0: tensor<?x?xf32>, %arg1 : tensor<2x2xi32>
// CHECK-LABEL: @pad_determine_val_quant
func.func @pad_determine_val_quant(%arg0: tensor<?x?xi32>, %arg1 : tensor<2x2xi32>) -> tensor<?x?xi32> {
- // CHECK-DAG: %[[ZERO:.+]] = "tosa.const"() <{value = dense<42> : tensor<i32>}
+ // CHECK-DAG: %[[ZERO:.+]] = "tosa.const"() <{value = dense<3> : tensor<1xi32>}
// CHECK-DAG: %[[PADDING:.+]] = tosa.const_shape {value = dense<[1, 0, 0, 1]> : tensor<4xindex>} : () -> !tosa.shape<4>
// CHECK: tosa.pad %arg0, %[[PADDING]], %[[ZERO]]
+ %pad_const = "tosa.const"() {value = dense<3> : tensor<1xi32>} : () -> tensor<1xi32>
%0 = tosa.const_shape { value = dense<[1, 0, 0, 1]> : tensor<4xindex>} : () -> !tosa.shape<4>
- %1 = tosa.pad %arg0, %0 {input_zp = 42 : i32} : (tensor<?x?xi32>, !tosa.shape<4>) -> tensor<?x?xi32>
+ %1 = tosa.pad %arg0, %0, %pad_const {input_zp = 42 : i32} : (tensor<?x?xi32>, !tosa.shape<4>, tensor<1xi32>) -> tensor<?x?xi32>
return %1 : tensor<?x?xi32>
}
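As these test updates show, pad_const is now a rank-1 scalar tensor
(tensor<1xf32>) rather than a rank-0 tensor (tensor<f32>); a minimal
migration sketch (value names assumed for illustration):
  %pad_const = "tosa.const"() {value = dense<0.0> : tensor<1xf32>} : () -> tensor<1xf32>
  %1 = tosa.pad %arg0, %shape, %pad_const : (tensor<?x?xf32>, !tosa.shape<4>, tensor<1xf32>) -> tensor<?x?xf32>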
diff --git a/mlir/test/Dialect/Tosa/constant-op-fold.mlir b/mlir/test/Dialect/Tosa/constant-op-fold.mlir
index e6fb741df9598..013ac3c547b37 100644
--- a/mlir/test/Dialect/Tosa/constant-op-fold.mlir
+++ b/mlir/test/Dialect/Tosa/constant-op-fold.mlir
@@ -136,6 +136,17 @@ func.func @transpose_nofold_dense_resource() -> tensor<2x2xf32> {
// -----
+// CHECK-LABEL: @transpose_nofold_quantized_types
+func.func @transpose_nofold_quantized_types() -> tensor<1x1x2x2x!quant.uniform<i8<-127:127>:f32:3, {1.000000e-01,1.000000e-01}>> {
+ %perms = "tosa.const"() {value = dense<[1, 2, 3, 0]> : tensor<4xi32>} : () -> tensor<4xi32>
+ %input = "tosa.const"() {value = dense<-127> : tensor<2x1x1x2xi8>} : () -> tensor<2x1x1x2x!quant.uniform<i8<-127:127>:f32:3, {1.000000e-01,1.000000e-01}>>
+ // CHECK: tosa.transpose
+ %0 = tosa.transpose %input, %perms : (tensor<2x1x1x2x!quant.uniform<i8<-127:127>:f32:3, {1.000000e-01,1.000000e-01}>>, tensor<4xi32>) -> tensor<1x1x2x2x!quant.uniform<i8<-127:127>:f32:3, {1.000000e-01,1.000000e-01}>>
+ return %0: tensor<1x1x2x2x!quant.uniform<i8<-127:127>:f32:3, {1.000000e-01,1.000000e-01}>>
+}
+
+// -----
+
// CHECK-LABEL: @fold_add_zero_rhs_f32
func.func @fold_add_zero_rhs_f32(%arg0: tensor<f32>) -> tensor<f32> {
%zero = "tosa.const"() {value = dense<0.0> : tensor<f32>} : () -> tensor<f32>
diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir
index f35c37a1ef70f..7698e28d68e60 100644
--- a/mlir/test/Dialect/Tosa/invalid.mlir
+++ b/mlir/test/Dialect/Tosa/invalid.mlir
@@ -44,7 +44,7 @@ func.func @test_conv2d(%arg0: tensor<*xi8>, %arg1: tensor<16x3x3x4xi8>, %arg2: t
func.func @test_conv2d(%arg0: tensor<1x29x29x4xi8>, %arg1: tensor<*xi8>, %arg2: tensor<16xi8>) -> tensor<1x27x27x16xi8> {
%zp = "tosa.const"() {value = dense<0> : tensor<1xi8>} : () -> tensor<1xi8>
- // expected-error at +1 {{'tosa.conv2d' op operand #1 must be 4D tensor of 4-bit signless integer or 8-bit signless integer or Quint8 type or Qint4 type or Qint8 type or Qint16 type or Qint32 type or floating-point values, but got 'tensor<*xi8>'}}
+ // expected-error at +1 {{'tosa.conv2d' op operand #1 must be 4D tensor of 4-bit signless integer or 8-bit signless integer or Quint8 type or Qint4 type or Qint8 type or Qint16 type or Qint32 type or floating-point or f8E4M3FN type or f8E5M2 type values, but got 'tensor<*xi8>'}}
%0 = tosa.conv2d %arg0, %arg1, %arg2, %zp, %zp {acc_type = i32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>}
: (tensor<1x29x29x4xi8>, tensor<*xi8>, tensor<16xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x27x27x16xi8>
return %0 : tensor<1x27x27x16xi8>
@@ -174,8 +174,7 @@ func.func @test_concat(%arg0 : tensor<2x1xf32>, %arg1 : tensor<2x2xf32>) -> tens
// -----
func.func @test_concat_element_type_mismatch(%arg0 : tensor<1x2xf32>, %arg1 : tensor<2x2xf32>) -> tensor<?x?xi8> {
- // expected-error at +2 {{failed to infer returned types}}
- // expected-error at +1 {{'tosa.concat' op inferred type(s) 'tensor<3x2xf32>' are incompatible with return type(s) of operation 'tensor<?x?xi8>}}
+ // expected-error at +1 {{'tosa.concat' op expect input and output to have same element type, got 'f32' and 'i8'}}
%0 = tosa.concat %arg0, %arg1 {axis = 0 : i32} : (tensor<1x2xf32>, tensor<2x2xf32>) -> tensor<?x?xi8>
return %0 : tensor<?x?xi8>
}
@@ -190,10 +189,10 @@ func.func @test_pad_non_const(%arg0: tensor<13x21x3xf32>, %arg1: !tosa.shape<6>)
// -----
-func.func @test_pad_non_const(%arg0: tensor<13x21x3xi8>, %arg1: tensor<i8>) -> tensor<13x21x3xi8> {
+func.func @test_pad_non_const(%arg0: tensor<13x21x3xi8>, %arg1: tensor<1xi8>) -> tensor<13x21x3xi8> {
%0 = tosa.const_shape {value = dense<[0, 0, 0, 1, 0, 1]> : tensor<6xindex>} : () -> !tosa.shape<6>
// expected-error at +1 {{'tosa.pad' op pad_const of pad is not constant}}
- %1 = tosa.pad %arg0, %0, %arg1 : (tensor<13x21x3xi8>, !tosa.shape<6>, tensor<i8>) -> tensor<13x21x3xi8>
+ %1 = tosa.pad %arg0, %0, %arg1 : (tensor<13x21x3xi8>, !tosa.shape<6>, tensor<1xi8>) -> tensor<13x21x3xi8>
return %1 : tensor<13x21x3xi8>
}
@@ -220,7 +219,7 @@ func.func @test_pad_invalid_padding_rank(%arg0: tensor<13x21xf32>) {
func.func @test_pad_invalid_padConst_rank(%arg0: tensor<13x21xf32>, %arg1: tensor<2x2xi32>) {
%0 = tosa.const_shape {value = dense<1> : tensor<4xindex>} : () -> !tosa.shape<4>
%1 = "tosa.const"() {value = dense<3.14> : tensor<2xf32>} : () -> tensor<2xf32>
- // expected-error at +1 {{'tosa.pad' op operand #2 must be 0D tensor of number values, but got 'tensor<2xf32>'}}
+ // expected-error at +1 {{'tosa.pad' op operand #2 must be tosa-conformant scalar tensor of number_extended values, but got 'tensor<2xf32>'}}
%2 = tosa.pad %arg0, %0, %1 : (tensor<13x21xf32>, !tosa.shape<4>, tensor<2xf32>) -> tensor<13x21xf32>
return
}
@@ -425,8 +424,7 @@ func.func @test_reduce_min_invalid_output_rank(%arg0 : tensor<i32>) -> () {
func.func @test_reshape_type_mismatch(%arg0 : tensor<13x21x3xf32>) -> () {
%1 = tosa.const_shape {value = dense<[13, 21, 3, 1]> : tensor<4xindex>} : () -> !tosa.shape<4>
- // expected-error at +2 {{failed to infer returned types}}
- // expected-error at +1 {{'tosa.reshape' op inferred type(s) 'tensor<13x21x3x1xf32>' are incompatible with return type(s) of operation 'tensor<13x21x3x1xi32>'}}
+ // expected-error at +1 {{'tosa.reshape' op expect input and output to have same element type, got 'f32' and 'i32'}}
%0 = tosa.reshape %arg0, %1 : (tensor<13x21x3xf32>, !tosa.shape<4>) -> tensor<13x21x3x1xi32>
return
}
@@ -435,7 +433,7 @@ func.func @test_reshape_type_mismatch(%arg0 : tensor<13x21x3xf32>) -> () {
func.func @test_reshape_static_zero_dim_input(%arg0 : tensor<13x0x3xf32>) -> () {
%s = tosa.const_shape {value = dense<[13, 21, 3]> : tensor<3xindex>} : () -> !tosa.shape<3>
- // expected-error at +1 {{'tosa.reshape' op operand #0 must be tosa-conformant tensor of number values, but got 'tensor<13x0x3xf32>'}}
+ // expected-error at +1 {{'tosa.reshape' op operand #0 must be tosa-conformant tensor of number_extended values, but got 'tensor<13x0x3xf32>'}}
%0 = "tosa.reshape"(%arg0, %s) : (tensor<13x0x3xf32>, !tosa.shape<3>) -> tensor<13x0x3xf32>
return
}
@@ -444,7 +442,7 @@ func.func @test_reshape_static_zero_dim_input(%arg0 : tensor<13x0x3xf32>) -> ()
func.func @test_reshape_zero_dim_input(%arg0 : tensor<?x0x3xf32>) -> () {
%s = tosa.const_shape {value = dense<[13, 21, 3]> : tensor<3xindex>} : () -> !tosa.shape<3>
- // expected-error at +1 {{'tosa.reshape' op operand #0 must be tosa-conformant tensor of number values, but got 'tensor<?x0x3xf32>'}}
+ // expected-error at +1 {{'tosa.reshape' op operand #0 must be tosa-conformant tensor of number_extended values, but got 'tensor<?x0x3xf32>'}}
%0 = "tosa.reshape"(%arg0, %s) : (tensor<?x0x3xf32>, !tosa.shape<3>) -> tensor<13x0x3xf32>
return
}
@@ -516,7 +514,7 @@ func.func @test_reshape_invalid_tensor_dim(%arg0 : tensor<4x?xf32>) -> () {
func.func @test_reverse_axis_out_of_range(%arg0 : tensor<13x21x3xf32>) -> () {
// expected-error at +1 {{'tosa.reverse' op expect input tensor rank (3) to be larger than reverse axis (5)}}
- %0 = tosa.reverse %arg0 {axis = 5 : i32} : (tensor<13x21x3xf32>) -> tensor<?x?x?xi32>
+ %0 = tosa.reverse %arg0 {axis = 5 : i32} : (tensor<13x21x3xf32>) -> tensor<?x?x?xf32>
return
}
@@ -524,7 +522,7 @@ func.func @test_reverse_axis_out_of_range(%arg0 : tensor<13x21x3xf32>) -> () {
func.func @test_reshape_zero_dim_input(%arg0 : tensor<?x0x3xf32>) -> () {
%1 = tosa.const_shape {value = dense<[13, 21, 3]> : tensor<3xindex>} : () -> !tosa.shape<3>
- // expected-error at +1 {{'tosa.reshape' op operand #0 must be tosa-conformant tensor of number values, but got 'tensor<?x0x3xf32>'}}
+ // expected-error at +1 {{'tosa.reshape' op operand #0 must be tosa-conformant tensor of number_extended values, but got 'tensor<?x0x3xf32>'}}
%0 = "tosa.reshape"(%arg0, %1) : (tensor<?x0x3xf32>, !tosa.shape<3>) -> tensor<13x0x3xf32>
return
}
@@ -540,7 +538,7 @@ func.func @test_const_attribute_type_mismatch() -> tensor<100x100xf32> {
// -----
func.func @test_conv2d_static_zero_dim_input(%arg0: tensor<1x29x0x4xf32>, %arg1: tensor<16x3x3x4xf32>, %arg2: tensor<16xf32>) -> tensor<1x27x27x16xf32> {
- // expected-error at +1 {{'tosa.conv2d' op operand #0 must be 4-d tosa-conformant tensor, but got 'tensor<1x29x0x4xf32>'}}
+ // expected-error at +1 {{'tosa.conv2d' op operand #0 must be 4-d tosa-conformant tensor extended, but got 'tensor<1x29x0x4xf32>'}}
%0 = "tosa.conv2d"(%arg0, %arg1, %arg2) {acc_type = f32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>}
: (tensor<1x29x0x4xf32>, tensor<16x3x3x4xf32>, tensor<16xf32>) -> tensor<1x27x27x16xf32>
return %0 : tensor<1x27x27x16xf32>
@@ -549,7 +547,7 @@ func.func @test_conv2d_static_zero_dim_input(%arg0: tensor<1x29x0x4xf32>, %arg1:
// -----
func.func @test_conv2d_zero_dim_input(%arg0: tensor<1x?x0x4xf32>, %arg1: tensor<16x3x3x4xf32>, %arg2: tensor<16xf32>) -> tensor<1x27x27x16xf32> {
- // expected-error at +1 {{'tosa.conv2d' op operand #0 must be 4-d tosa-conformant tensor, but got 'tensor<1x?x0x4xf32>'}}
+ // expected-error at +1 {{'tosa.conv2d' op operand #0 must be 4-d tosa-conformant tensor extended, but got 'tensor<1x?x0x4xf32>'}}
%0 = "tosa.conv2d"(%arg0, %arg1, %arg2) {acc_type = f32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>}
: (tensor<1x?x0x4xf32>, tensor<16x3x3x4xf32>, tensor<16xf32>) -> tensor<1x27x27x16xf32>
return %0 : tensor<1x27x27x16xf32>
@@ -559,7 +557,7 @@ func.func @test_conv2d_zero_dim_input(%arg0: tensor<1x?x0x4xf32>, %arg1: tensor<
// -----
func.func @test_avg_pool2d_static_zero_dim_input(%arg0: tensor<1x0x7x9xf32>) -> tensor<1x7x7x9xf32> {
- // expected-error at +1 {{'tosa.avg_pool2d' op operand #0 must be 4-d tosa-conformant tensor, but got 'tensor<1x0x7x9xf32>'}}
+ // expected-error at +1 {{'tosa.avg_pool2d' op operand #0 must be 4-d tosa-conformant tensor extended, but got 'tensor<1x0x7x9xf32>'}}
%0 = "tosa.avg_pool2d"(%arg0) {acc_type = f32, kernel = array<i64: 2, 2>, pad = array<i64: 0, 1, 0, 1>, stride = array<i64: 1, 1>}
: (tensor<1x0x7x9xf32>) -> tensor<1x7x7x9xf32>
return %0 : tensor<1x7x7x9xf32>
@@ -568,7 +566,7 @@ func.func @test_avg_pool2d_static_zero_dim_input(%arg0: tensor<1x0x7x9xf32>) ->
// -----
func.func @test_avg_pool2d_zero_dim_input(%arg0: tensor<1x0x?x9xf32>) -> tensor<1x7x7x9xf32> {
- // expected-error at +1 {{'tosa.avg_pool2d' op operand #0 must be 4-d tosa-conformant tensor, but got 'tensor<1x0x?x9xf32>'}}
+ // expected-error at +1 {{'tosa.avg_pool2d' op operand #0 must be 4-d tosa-conformant tensor extended, but got 'tensor<1x0x?x9xf32>'}}
%0 = "tosa.avg_pool2d"(%arg0) {acc_type = f32, kernel = array<i64: 2, 2>, pad = array<i64: 0, 1, 0, 1>, stride = array<i64: 1, 1>}
: (tensor<1x0x?x9xf32>) -> tensor<1x7x7x9xf32>
return %0 : tensor<1x7x7x9xf32>
@@ -625,7 +623,7 @@ func.func @test_slice_invalid_start() {
%0 = tensor.empty() : tensor<4x31x31xf32>
%start = tosa.const_shape {value = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
%size = tosa.const_shape {value = dense<[1, 1, 1]> : tensor<3xindex>} : () -> !tosa.shape<3>
- // expected-error at +1 {{'tosa.slice' op length of start attribute is not equal rank of input shape}}
+ // expected-error at +1 {{'tosa.slice' op length of start is not equal to rank of input shape}}
%3 = tosa.slice %0, %start, %size : (tensor<4x31x31xf32>, !tosa.shape<2>, !tosa.shape<3>) -> tensor<*xf32>
return
}
@@ -636,7 +634,7 @@ func.func @test_slice_invalid_size() {
%0 = tensor.empty() : tensor<4x31x31xf32>
%start = tosa.const_shape {value = dense<[1, 1, 1]> : tensor<3xindex>} : () -> !tosa.shape<3>
%size = tosa.const_shape {value = dense<[1]> : tensor<1xindex>} : () -> !tosa.shape<1>
- // expected-error at +1 {{'tosa.slice' op length of size attribute is not equal rank of input shape}}
+ // expected-error at +1 {{'tosa.slice' op length of size is not equal to rank of input shape}}
%3 = tosa.slice %0, %start, %size : (tensor<4x31x31xf32>, !tosa.shape<3>, !tosa.shape<1>) -> tensor<*xf32>
return
}
diff --git a/mlir/test/Dialect/Tosa/ops.mlir b/mlir/test/Dialect/Tosa/ops.mlir
index faac8b7c1ff93..85b6b27c7e849 100644
--- a/mlir/test/Dialect/Tosa/ops.mlir
+++ b/mlir/test/Dialect/Tosa/ops.mlir
@@ -582,9 +582,9 @@ func.func @test_pad(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
// -----
// CHECK-LABEL: pad_explicit_value
func.func @test_pad_explicit_value(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
- %0 = "tosa.const"() {value = dense<3.14> : tensor<f32>} : () -> tensor<f32>
+ %0 = "tosa.const"() {value = dense<3.14> : tensor<1xf32>} : () -> tensor<1xf32>
%padding = tosa.const_shape {value = dense<0> : tensor<6xindex>} : () -> !tosa.shape<6>
- %1 = tosa.pad %arg0, %padding, %0 : (tensor<13x21x3xf32>, !tosa.shape<6>, tensor<f32>) -> tensor<13x21x3xf32>
+ %1 = tosa.pad %arg0, %padding, %0 : (tensor<13x21x3xf32>, !tosa.shape<6>, tensor<1xf32>) -> tensor<13x21x3xf32>
return %1 : tensor<13x21x3xf32>
}
@@ -763,3 +763,285 @@ func.func @test_const_shape() -> !tosa.shape<4> {
%cst = tosa.const_shape {value = dense<1> : tensor<4xindex>} : () -> !tosa.shape<4>
return %cst : !tosa.shape<4>
}
+
+// F8 support tests
+
+// -----
+// CHECK-LABEL: argmax_f8E5M2
+func.func @test_argmax_f8E5M2(%arg0: tensor<12x8x16xf8E5M2>) -> tensor<12x16xi32> {
+ %0 = tosa.argmax %arg0 { axis = 1 : i32 } : (tensor<12x8x16xf8E5M2>) -> tensor<12x16xi32>
+ return %0 : tensor<12x16xi32>
+}
+
+// -----
+// CHECK-LABEL: avg_pool2d_f8E5M2
+func.func @test_avg_pool2d_f8E5M2(%arg0: tensor<1x7x7x9xf8E5M2>) -> tensor<1x7x7x9xf8E5M2> {
+ %0 = tosa.avg_pool2d %arg0 {acc_type = f16, kernel = array<i64: 2, 2>, pad = array<i64: 0, 1, 0, 1>, stride = array<i64: 1, 1>} : (tensor<1x7x7x9xf8E5M2>) -> tensor<1x7x7x9xf8E5M2>
+ return %0 : tensor<1x7x7x9xf8E5M2>
+}
+
+// -----
+// CHECK-LABEL: conv2d_f8E5M2
+func.func @test_conv2d_f8E5M2(%arg0: tensor<1x4x4x4xf8E5M2>, %arg1: tensor<8x1x1x4xf8E5M2>, %arg2: tensor<8xf16>) -> tensor<1x4x4x8xf16> {
+ %0 = tosa.conv2d %arg0, %arg1, %arg2 {acc_type = f16, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, local_bound = true} : (tensor<1x4x4x4xf8E5M2>, tensor<8x1x1x4xf8E5M2>, tensor<8xf16>) -> tensor<1x4x4x8xf16>
+ return %0 : tensor<1x4x4x8xf16>
+}
+
+// -----
+// CHECK-LABEL: conv3d_f8E5M2
+func.func @test_conv3d_f8E5M2(%arg0: tensor<1x4x8x21x17xf8E5M2>, %arg1: tensor<34x1x1x1x17xf8E5M2>, %arg2: tensor<34xf16>) -> tensor<1x4x8x21x34xf16> {
+ %0 = tosa.conv3d %arg0, %arg1, %arg2 {acc_type = f16, dilation = array<i64: 1, 1, 1>, pad = array<i64: 0, 0, 0, 0, 0, 0>, stride = array<i64: 1, 1, 1>} : (tensor<1x4x8x21x17xf8E5M2>, tensor<34x1x1x1x17xf8E5M2>, tensor<34xf16>) -> tensor<1x4x8x21x34xf16>
+ return %0 : tensor<1x4x8x21x34xf16>
+}
+
+// -----
+// CHECK-LABEL: depthwise_conv2d_f8E5M2
+func.func @test_depthwise_conv2d_f8E5M2(%arg0: tensor<1x4x4x4xf8E5M2>, %arg1: tensor<1x1x4x2xf8E5M2>, %arg2: tensor<8xf16>) -> tensor<1x4x4x8xf16> {
+ %0 = tosa.depthwise_conv2d %arg0, %arg1, %arg2 {acc_type = f16, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<1x4x4x4xf8E5M2>, tensor<1x1x4x2xf8E5M2>, tensor<8xf16>) -> tensor<1x4x4x8xf16>
+ return %0 : tensor<1x4x4x8xf16>
+}
+
+// -----
+// CHECK-LABEL: test_matmul_f8E5M2
+func.func @test_matmul_f8E5M2(%arg0: tensor<1x14x19xf8E5M2>, %arg1: tensor<1x19x28xf8E5M2>) -> tensor<1x14x28xf16> {
+ %0 = tosa.matmul %arg0, %arg1 : (tensor<1x14x19xf8E5M2>, tensor<1x19x28xf8E5M2>) -> tensor<1x14x28xf16>
+ return %0 : tensor<1x14x28xf16>
+}
+
+// -----
+// CHECK-LABEL: max_pool2d_f8E5M2
+func.func @test_max_pool2d_f8E5M2(%arg0: tensor<1x32x32x8xf8E5M2>) -> tensor<1x32x32x8xf8E5M2> {
+ %0 = tosa.max_pool2d %arg0 {kernel = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<1x32x32x8xf8E5M2>) -> tensor<1x32x32x8xf8E5M2>
+ return %0 : tensor<1x32x32x8xf8E5M2>
+}
+
+// -----
+// CHECK-LABEL: transpose_conv2d_f8E5M2
+func.func @test_transpose_conv2d_f8E5M2(%arg0: tensor<1x32x32x8xf8E5M2>, %arg1: tensor<16x1x1x8xf8E5M2>, %arg2: tensor<16xf16>) -> tensor<1x32x32x16xf16> {
+ %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2 {acc_type = f16, out_pad = array<i64: 0, 0, 0, 0>, out_shape = array<i64: 1, 32, 32, 16>, stride = array<i64: 1, 1>} : (tensor<1x32x32x8xf8E5M2>, tensor<16x1x1x8xf8E5M2>, tensor<16xf16>) -> tensor<1x32x32x16xf16>
+ return %0 : tensor<1x32x32x16xf16>
+}
+
+// -----
+// CHECK-LABEL: const_f8E5M2
+func.func @test_const_f8E5M2(%arg0 : index) -> tensor<4xf8E5M2> {
+ %0 = "tosa.const"() {value = dense<[3.0, -0.0, -1.0, 2.0]> : tensor<4xf8E5M2>} : () -> tensor<4xf8E5M2>
+ return %0 : tensor<4xf8E5M2>
+}
+
+// -----
+// CHECK-LABEL: cast_f8E5M2
+func.func @test_cast_f8E5M2(%arg0: tensor<13x21x3xf8E5M2>) -> tensor<13x21x3xf16> {
+ %0 = tosa.cast %arg0 : (tensor<13x21x3xf8E5M2>) -> tensor<13x21x3xf16>
+ return %0 : tensor<13x21x3xf16>
+}
+
+// -----
+// CHECK-LABEL: concat_f8E5M2
+func.func @test_concat_f8E5M2(%arg0: tensor<13x21x3xf8E5M2>, %arg1: tensor<13x21x3xf8E5M2>) -> tensor<26x21x3xf8E5M2> {
+ %0 = tosa.concat %arg0, %arg1 {axis = 0 : i32} : (tensor<13x21x3xf8E5M2>, tensor<13x21x3xf8E5M2>) -> tensor<26x21x3xf8E5M2>
+ return %0 : tensor<26x21x3xf8E5M2>
+}
+
+// -----
+// CHECK-LABEL: pad_f8E5M2
+func.func @test_pad_f8E5M2(%arg0: tensor<13x21x3xf8E5M2>) -> tensor<13x21x3xf8E5M2> {
+ %padding = tosa.const_shape {value = dense<0> : tensor<6xindex>} : () -> !tosa.shape<6>
+ %cst = "tosa.const"() { value = dense<-0.0> : tensor<1xf8E5M2> } : () -> tensor<1xf8E5M2>
+ %0 = tosa.pad %arg0, %padding, %cst : (tensor<13x21x3xf8E5M2>, !tosa.shape<6>, tensor<1xf8E5M2>) -> tensor<13x21x3xf8E5M2>
+ return %0 : tensor<13x21x3xf8E5M2>
+}
+
+// -----
+// CHECK-LABEL: reshape_f8E5M2
+func.func @test_reshape_f8E5M2(%arg0: tensor<13x21x3xf8E5M2>) -> tensor<1x819xf8E5M2> {
+ %1 = tosa.const_shape {value = dense<[1, 819]> : tensor<2xindex>} : () -> !tosa.shape<2>
+ %0 = tosa.reshape %arg0, %1 : (tensor<13x21x3xf8E5M2>, !tosa.shape<2>) -> tensor<1x819xf8E5M2>
+ return %0 : tensor<1x819xf8E5M2>
+}
+
+// -----
+// CHECK-LABEL: reverse_f8E5M2
+func.func @test_reverse_f8E5M2(%arg0: tensor<13x21x3xf8E5M2>) -> tensor<13x21x3xf8E5M2> {
+ %0 = tosa.reverse %arg0 {axis = 0 : i32} : (tensor<13x21x3xf8E5M2>) -> tensor<13x21x3xf8E5M2>
+ return %0 : tensor<13x21x3xf8E5M2>
+}
+
+// -----
+// CHECK-LABEL: slice_f8E5M2
+func.func @test_slice_f8E5M2(%arg0: tensor<13x21x3xf8E5M2>) -> tensor<4x11x1xf8E5M2> {
+ %0 = tosa.const_shape {value = dense<[4, 11, 1]> : tensor<3xindex>} : () -> !tosa.shape<3>
+ %1 = tosa.const_shape {value = dense<[6, 8, 0]> : tensor<3xindex>} : () -> !tosa.shape<3>
+ %2 = tosa.slice %arg0, %0, %1 : (tensor<13x21x3xf8E5M2>, !tosa.shape<3>, !tosa.shape<3>) -> tensor<4x11x1xf8E5M2>
+ return %2 : tensor<4x11x1xf8E5M2>
+}
+
+// -----
+// CHECK-LABEL: tile_f8E5M2
+func.func @test_tile_f8E5M2(%arg0: tensor<13x21x3xf8E5M2>) -> tensor<39x21x6xf8E5M2> {
+ %cst = tosa.const_shape { value = dense<[3, 1, 2]> : tensor<3xindex> } : () -> !tosa.shape<3>
+ %0 = tosa.tile %arg0, %cst: (tensor<13x21x3xf8E5M2>, !tosa.shape<3>) -> tensor<39x21x6xf8E5M2>
+ return %0 : tensor<39x21x6xf8E5M2>
+}
+
+// -----
+// CHECK-LABEL: transpose_f8E5M2
+func.func @test_transpose_f8E5M2(%arg0: tensor<13x21x3xf8E5M2>) -> tensor<3x13x21xf8E5M2> {
+ %0 = "tosa.const"() {value = dense<[2, 0, 1]> : tensor<3xi32>} : () -> tensor<3xi32>
+ %1 = tosa.transpose %arg0, %0 : (tensor<13x21x3xf8E5M2>, tensor<3xi32>) -> tensor<3x13x21xf8E5M2>
+ return %1 : tensor<3x13x21xf8E5M2>
+}
+
+// -----
+// CHECK-LABEL: gather_f8E5M2
+func.func @test_gather_f8E5M2(%arg0: tensor<13x21x3xf8E5M2>, %arg1: tensor<13x26xi32>) -> tensor<13x26x3xf8E5M2> {
+ %0 = tosa.gather %arg0, %arg1 : (tensor<13x21x3xf8E5M2>, tensor<13x26xi32>) -> tensor<13x26x3xf8E5M2>
+ return %0 : tensor<13x26x3xf8E5M2>
+}
+
+// -----
+// CHECK-LABEL: scatter_f8E5M2
+func.func @test_scatter_f8E5M2(%arg0: tensor<13x21x3xf8E5M2>, %arg1: tensor<13x26xi32>, %arg2: tensor<13x26x3xf8E5M2>) -> tensor<13x21x3xf8E5M2> {
+ %0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x21x3xf8E5M2>, tensor<13x26xi32>, tensor<13x26x3xf8E5M2>) -> tensor<13x21x3xf8E5M2>
+ return %0 : tensor<13x21x3xf8E5M2>
+}
+
+// -----
+// CHECK-LABEL: argmax_f8E4M3FN
+func.func @test_argmax_f8E4M3FN(%arg0: tensor<12x8x16xf8E4M3FN>) -> tensor<12x16xi32> {
+ %0 = tosa.argmax %arg0 { axis = 1 : i32 } : (tensor<12x8x16xf8E4M3FN>) -> tensor<12x16xi32>
+ return %0 : tensor<12x16xi32>
+}
+
+// -----
+// CHECK-LABEL: avg_pool2d_f8E4M3FN
+func.func @test_avg_pool2d_f8E4M3FN(%arg0: tensor<1x7x7x9xf8E4M3FN>) -> tensor<1x7x7x9xf8E4M3FN> {
+ %0 = tosa.avg_pool2d %arg0 {acc_type = f16, kernel = array<i64: 2, 2>, pad = array<i64: 0, 1, 0, 1>, stride = array<i64: 1, 1>} : (tensor<1x7x7x9xf8E4M3FN>) -> tensor<1x7x7x9xf8E4M3FN>
+ return %0 : tensor<1x7x7x9xf8E4M3FN>
+}
+
+// -----
+// CHECK-LABEL: conv2d_f8E4M3FN
+func.func @test_conv2d_f8E4M3FN(%arg0: tensor<1x4x4x4xf8E4M3FN>, %arg1: tensor<8x1x1x4xf8E4M3FN>, %arg2: tensor<8xf16>) -> tensor<1x4x4x8xf16> {
+ %0 = tosa.conv2d %arg0, %arg1, %arg2 {acc_type = f16, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, local_bound = true} : (tensor<1x4x4x4xf8E4M3FN>, tensor<8x1x1x4xf8E4M3FN>, tensor<8xf16>) -> tensor<1x4x4x8xf16>
+ return %0 : tensor<1x4x4x8xf16>
+}
+
+// -----
+// CHECK-LABEL: conv3d_f8E4M3FN
+func.func @test_conv3d_f8E4M3FN(%arg0: tensor<1x4x8x21x17xf8E4M3FN>, %arg1: tensor<34x1x1x1x17xf8E4M3FN>, %arg2: tensor<34xf16>) -> tensor<1x4x8x21x34xf16> {
+ %0 = tosa.conv3d %arg0, %arg1, %arg2 {acc_type = f16, dilation = array<i64: 1, 1, 1>, pad = array<i64: 0, 0, 0, 0, 0, 0>, stride = array<i64: 1, 1, 1>} : (tensor<1x4x8x21x17xf8E4M3FN>, tensor<34x1x1x1x17xf8E4M3FN>, tensor<34xf16>) -> tensor<1x4x8x21x34xf16>
+ return %0 : tensor<1x4x8x21x34xf16>
+}
+
+// -----
+// CHECK-LABEL: depthwise_conv2d_f8E4M3FN
+func.func @test_depthwise_conv2d_f8E4M3FN(%arg0: tensor<1x4x4x4xf8E4M3FN>, %arg1: tensor<1x1x4x2xf8E4M3FN>, %arg2: tensor<8xf16>) -> tensor<1x4x4x8xf16> {
+ %0 = tosa.depthwise_conv2d %arg0, %arg1, %arg2 {acc_type = f16, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<1x4x4x4xf8E4M3FN>, tensor<1x1x4x2xf8E4M3FN>, tensor<8xf16>) -> tensor<1x4x4x8xf16>
+ return %0 : tensor<1x4x4x8xf16>
+}
+
+// -----
+// CHECK-LABEL: matmul_f8E4M3FN
+func.func @test_matmul_f8E4M3FN(%arg0: tensor<1x14x19xf8E4M3FN>, %arg1: tensor<1x19x28xf8E4M3FN>) -> tensor<1x14x28xf16> {
+ %0 = tosa.matmul %arg0, %arg1 : (tensor<1x14x19xf8E4M3FN>, tensor<1x19x28xf8E4M3FN>) -> tensor<1x14x28xf16>
+ return %0 : tensor<1x14x28xf16>
+}
+
+// -----
+// CHECK-LABEL: max_pool2d_f8E4M3FN
+func.func @test_max_pool2d_f8E4M3FN(%arg0: tensor<1x32x32x8xf8E4M3FN>) -> tensor<1x32x32x8xf8E4M3FN> {
+ %0 = tosa.max_pool2d %arg0 {kernel = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<1x32x32x8xf8E4M3FN>) -> tensor<1x32x32x8xf8E4M3FN>
+ return %0 : tensor<1x32x32x8xf8E4M3FN>
+}
+
+// -----
+// CHECK-LABEL: transpose_conv2d_f8E4M3FN
+func.func @test_transpose_conv2d_f8E4M3FN(%arg0: tensor<1x32x32x8xf8E4M3FN>, %arg1: tensor<16x1x1x8xf8E4M3FN>, %arg2: tensor<16xf16>) -> tensor<1x32x32x16xf16> {
+ %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2 {acc_type = f16, out_pad = array<i64: 0, 0, 0, 0>, out_shape = array<i64: 1, 32, 32, 16>, stride = array<i64: 1, 1>} : (tensor<1x32x32x8xf8E4M3FN>, tensor<16x1x1x8xf8E4M3FN>, tensor<16xf16>) -> tensor<1x32x32x16xf16>
+ return %0 : tensor<1x32x32x16xf16>
+}
+
+// -----
+// CHECK-LABEL: const_f8E4M3FN
+func.func @test_const_f8E4M3FN(%arg0 : index) -> tensor<4xf8E4M3FN> {
+ %0 = "tosa.const"() {value = dense<[3.0, -0.0, -1.0, 2.0]> : tensor<4xf8E4M3FN>} : () -> tensor<4xf8E4M3FN>
+ return %0 : tensor<4xf8E4M3FN>
+}
+
+// -----
+// CHECK-LABEL: cast_f8E4M3FN
+func.func @test_cast_f8E4M3FN(%arg0: tensor<13x21x3xf8E4M3FN>) -> tensor<13x21x3xf16> {
+ %0 = tosa.cast %arg0 : (tensor<13x21x3xf8E4M3FN>) -> tensor<13x21x3xf16>
+ return %0 : tensor<13x21x3xf16>
+}
+
+// -----
+// CHECK-LABEL: concat_f8E4M3FN
+func.func @test_concat_f8E4M3FN(%arg0: tensor<13x21x3xf8E4M3FN>, %arg1: tensor<13x21x3xf8E4M3FN>) -> tensor<26x21x3xf8E4M3FN> {
+ %0 = tosa.concat %arg0, %arg1 {axis = 0 : i32} : (tensor<13x21x3xf8E4M3FN>, tensor<13x21x3xf8E4M3FN>) -> tensor<26x21x3xf8E4M3FN>
+ return %0 : tensor<26x21x3xf8E4M3FN>
+}
+
+// -----
+// CHECK-LABEL: pad_f8E4M3FN
+func.func @test_pad_f8E4M3FN(%arg0: tensor<13x21x3xf8E4M3FN>) -> tensor<13x21x3xf8E4M3FN> {
+ %padding = tosa.const_shape {value = dense<0> : tensor<6xindex>} : () -> !tosa.shape<6>
+ %cst = "tosa.const"() { value = dense<-0.0> : tensor<1xf8E4M3FN> } : () -> tensor<1xf8E4M3FN>
+ %0 = tosa.pad %arg0, %padding, %cst : (tensor<13x21x3xf8E4M3FN>, !tosa.shape<6>, tensor<1xf8E4M3FN>) -> tensor<13x21x3xf8E4M3FN>
+ return %0 : tensor<13x21x3xf8E4M3FN>
+}
+
+// -----
+// CHECK-LABEL: reshape_f8E4M3FN
+func.func @test_reshape_f8E4M3FN(%arg0: tensor<13x21x3xf8E4M3FN>) -> tensor<1x819xf8E4M3FN> {
+ %1 = tosa.const_shape {value = dense<[1, 819]> : tensor<2xindex>} : () -> !tosa.shape<2>
+ %0 = tosa.reshape %arg0, %1 : (tensor<13x21x3xf8E4M3FN>, !tosa.shape<2>) -> tensor<1x819xf8E4M3FN>
+ return %0 : tensor<1x819xf8E4M3FN>
+}
+
+// -----
+// CHECK-LABEL: reverse_f8E4M3FN
+func.func @test_reverse_f8E4M3FN(%arg0: tensor<13x21x3xf8E4M3FN>) -> tensor<13x21x3xf8E4M3FN> {
+ %0 = tosa.reverse %arg0 {axis = 0 : i32} : (tensor<13x21x3xf8E4M3FN>) -> tensor<13x21x3xf8E4M3FN>
+ return %0 : tensor<13x21x3xf8E4M3FN>
+}
+
+// -----
+// CHECK-LABEL: slice_f8E4M3FN
+func.func @test_slice_f8E4M3FN(%arg0: tensor<13x21x3xf8E4M3FN>) -> tensor<4x11x1xf8E4M3FN> {
+ %0 = tosa.const_shape {value = dense<[4, 11, 1]> : tensor<3xindex>} : () -> !tosa.shape<3>
+ %1 = tosa.const_shape {value = dense<[6, 8, 0]> : tensor<3xindex>} : () -> !tosa.shape<3>
+ %2 = tosa.slice %arg0, %0, %1 : (tensor<13x21x3xf8E4M3FN>, !tosa.shape<3>, !tosa.shape<3>) -> tensor<4x11x1xf8E4M3FN>
+ return %2 : tensor<4x11x1xf8E4M3FN>
+}
+
+// -----
+// CHECK-LABEL: tile_f8E4M3FN
+func.func @test_tile_f8E4M3FN(%arg0: tensor<13x21x3xf8E4M3FN>) -> tensor<39x21x6xf8E4M3FN> {
+ %cst = tosa.const_shape { value = dense<[3, 1, 2]> : tensor<3xindex> } : () -> !tosa.shape<3>
+ %0 = tosa.tile %arg0, %cst: (tensor<13x21x3xf8E4M3FN>, !tosa.shape<3>) -> tensor<39x21x6xf8E4M3FN>
+ return %0 : tensor<39x21x6xf8E4M3FN>
+}
+
+// -----
+// CHECK-LABEL: transpose_f8E4M3FN
+func.func @test_transpose_f8E4M3FN(%arg0: tensor<13x21x3xf8E4M3FN>) -> tensor<3x13x21xf8E4M3FN> {
+ %0 = "tosa.const"() {value = dense<[2, 0, 1]> : tensor<3xi32>} : () -> tensor<3xi32>
+ %1 = tosa.transpose %arg0, %0 : (tensor<13x21x3xf8E4M3FN>, tensor<3xi32>) -> tensor<3x13x21xf8E4M3FN>
+ return %1 : tensor<3x13x21xf8E4M3FN>
+}
+
+// -----
+// CHECK-LABEL: gather_f8E4M3FN
+func.func @test_gather_f8E4M3FN(%arg0: tensor<13x21x3xf8E4M3FN>, %arg1: tensor<13x26xi32>) -> tensor<13x26x3xf8E4M3FN> {
+ %0 = tosa.gather %arg0, %arg1 : (tensor<13x21x3xf8E4M3FN>, tensor<13x26xi32>) -> tensor<13x26x3xf8E4M3FN>
+ return %0 : tensor<13x26x3xf8E4M3FN>
+}
+
+// -----
+// CHECK-LABEL: scatter_f8E4M3FN
+func.func @test_scatter_f8E4M3FN(%arg0: tensor<13x21x3xf8E4M3FN>, %arg1: tensor<13x26xi32>, %arg2: tensor<13x26x3xf8E4M3FN>) -> tensor<13x21x3xf8E4M3FN> {
+ %0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x21x3xf8E4M3FN>, tensor<13x26xi32>, tensor<13x26x3xf8E4M3FN>) -> tensor<13x21x3xf8E4M3FN>
+ return %0 : tensor<13x21x3xf8E4M3FN>
+}
diff --git a/mlir/test/Dialect/Tosa/tosa-decompose-depthwise.mlir b/mlir/test/Dialect/Tosa/tosa-decompose-depthwise.mlir
index 6562a7c2ab55c..3a03f13522db2 100644
--- a/mlir/test/Dialect/Tosa/tosa-decompose-depthwise.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-decompose-depthwise.mlir
@@ -59,13 +59,13 @@ func.func @depthwise_conv2d_as_mul_q(%arg0: tensor<4x10x10x2xi8>, %arg1: tensor<
func.func @depthwise_conv2d_as_mul_padded(%arg0: tensor<4x10x10x2xf32>, %arg1: tensor<1x1x2x3xf32>, %arg2: tensor<6xf32>) -> tensor<4x12x12x6xf32> {
// CHECK-DAG: %[[CONST0:.+]] = tosa.const_shape {value = dense<[4, 10, 10, 2, 1]> : tensor<5xindex>}
// CHECK-DAG: %[[pad:.+]] = tosa.const_shape {value = dense<[0, 0, 1, 1, 1, 1, 0, 0, 0, 0]> : tensor<10xindex>} : () -> !tosa.shape<10>
- // CHECK-DAG: %[[zero:.+]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<f32>}
+ // CHECK-DAG: %[[zero:.+]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<1xf32>}
// CHECK-DAG: %[[CONST3:.+]] = tosa.const_shape {value = dense<[1, 1, 1, 2, 3]> : tensor<5xindex>}
// CHECK-DAG: %[[CONST4:.+]] = tosa.const_shape {value = dense<[4, 12, 12, 6]> : tensor<4xindex>}
// CHECK-DAG: %[[CONST5:.+]] = tosa.const_shape {value = dense<[1, 1, 1, 6]> : tensor<4xindex>}
// CHECK-DAG: %[[SHIFT:.*]] = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
// CHECK: %[[reIn:.+]] = tosa.reshape %arg0, %[[CONST0]]
- // CHECK: %[[padded:.+]] = tosa.pad %[[reIn]], %[[pad]], %[[zero]] : (tensor<4x10x10x2x1xf32>, !tosa.shape<10>, tensor<f32>) -> tensor<4x12x12x2x1xf32>
+ // CHECK: %[[padded:.+]] = tosa.pad %[[reIn]], %[[pad]], %[[zero]] : (tensor<4x10x10x2x1xf32>, !tosa.shape<10>, tensor<1xf32>) -> tensor<4x12x12x2x1xf32>
// CHECK: %[[reArg1:.+]] = tosa.reshape %arg1, %[[CONST3]]
// CHECK: %[[mul:.+]] = tosa.mul %[[padded]], %[[reArg1]], %[[SHIFT]]
// CHECK: %[[reOut:.+]] = tosa.reshape %[[mul]], %[[CONST4]]
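Side note for reviewers skimming the new tests: the FP8 coverage above is limited to specific operators (argmax, pooling, the convolution family, matmul, const, cast, and the data-layout/gather/scatter ops), so elementwise arithmetic on FP8 values is expected to go through an explicit widening cast first. A minimal sketch of that pattern, not taken from this patch (the function name and shapes are invented for illustration):

func.func @fp8_add_via_cast(%a: tensor<4x4xf8E4M3FN>, %b: tensor<4x4xf8E4M3FN>) -> tensor<4x4xf16> {
  // tosa.cast is one of the ops extended to FP8, so widen both operands to f16 ...
  %a16 = tosa.cast %a : (tensor<4x4xf8E4M3FN>) -> tensor<4x4xf16>
  %b16 = tosa.cast %b : (tensor<4x4xf8E4M3FN>) -> tensor<4x4xf16>
  // ... and do the arithmetic at f16, which tosa.add already supports.
  %sum = tosa.add %a16, %b16 : (tensor<4x4xf16>, tensor<4x4xf16>) -> tensor<4x4xf16>
  return %sum : tensor<4x4xf16>
}

For reference, the two element types exercised here differ in layout: f8E5M2 has 5 exponent and 2 mantissa bits (IEEE-style, with infinities), while f8E4M3FN has 4 exponent and 3 mantissa bits and is finite-plus-NaN only, which is why each test is duplicated for both variants.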