[Mlir-commits] [mlir] [mlir] [TOSA] Allow any floating point type (PR #91745)
Matthias Gehre
llvmlistbot at llvm.org
Fri May 10 06:53:39 PDT 2024
https://github.com/mgehre-amd created https://github.com/llvm/llvm-project/pull/91745
After #86509 allowed all integer types in TOSA ops, this PR allows TOSA ops on all floating-point types.
This makes it easier to experiment with `f64` and 8-bit float types when spec conformance is not required.
From 0c68ac7bcab973e9e5b4d265379857f13ce49a35 Mon Sep 17 00:00:00 2001
From: Matthias Gehre <matthias.gehre at amd.com>
Date: Mon, 6 May 2024 12:06:48 +0200
Subject: [PATCH] [mlir] [TOSA] Allow any floating point type
---
mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td | 6 +++---
.../mlir/Dialect/Tosa/IR/TosaTypesBase.td | 21 ++++---------------
.../Tosa/Transforms/TosaValidation.cpp | 9 ++++----
mlir/test/Dialect/Tosa/invalid.mlir | 2 +-
mlir/test/Dialect/Tosa/level_check.mlir | 8 +++++++
5 files changed, 20 insertions(+), 26 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index 97a36c49d01b3..7871b46724a03 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -1857,11 +1857,11 @@ def Tosa_CastOp: Tosa_Op<"cast", [Pure,
}];
let arguments = (ins
- Tosa_Tensor_Plus_F64:$input
+ Tosa_Tensor:$input
);
let results = (outs
- Tosa_Tensor_Plus_F64:$output
+ Tosa_Tensor:$output
);
let assemblyFormat = "operands attr-dict `:` functional-type(operands, results)";
@@ -1944,7 +1944,7 @@ def Tosa_ConstOp : Tosa_Op<"const", [ConstantLike, Pure,
);
let results = (outs
- TensorOf<[AnyTypeOf<[Tosa_AnyNumber_Plus_F64]>]>:$output
+ TensorOf<[AnyTypeOf<[Tosa_AnyNumber]>]>:$output
);
let hasFolder = 1;
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
index 3687891fe4b7c..14fc9c7a6730c 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
@@ -71,28 +71,16 @@ def Tosa_QuantizedInt : AnyTypeOf<[ Tosa_QuantizedType<"uint8", [8], 0>,
Tosa_QuantizedType<"int16", [16, 0], 1>,
Tosa_QuantizedType<"int32", [32, 0], 1>]>;
-//===----------------------------------------------------------------------===//
-// Floating-point types.
-//===----------------------------------------------------------------------===//
-def Tosa_Float : AnyTypeOf<[
- F32,
- F16,
- BF16]>;
-
//===----------------------------------------------------------------------===//
// Multi-category types.
//===----------------------------------------------------------------------===//
-def Tosa_AnyNumber : AnyTypeOf<[Tosa_Int, Tosa_QuantizedInt, Tosa_Float],
+def Tosa_AnyNumber : AnyTypeOf<[Tosa_Int, Tosa_QuantizedInt, AnyFloat],
"number">;
-// Add F64 type support just for tosa::CastOp and tosa::ConstOp
-def Tosa_AnyNumber_Plus_F64 : AnyTypeOf<[Tosa_Int, Tosa_QuantizedInt, Tosa_Float, F64],
- "number_plus_f64">;
-
// For weight tensors from tosa::Conv2DOp, tosa::Conv3DOp,
// tosa::DepthwiseConv2DOp, tosa::TransposeConv2DOp, tosa::FullyConnectedOp
def Tosa_Weight : AnyTypeOf<[Tosa_Int4, Tosa_Int8,
- Tosa_QuantizedInt, Tosa_Float]>;
+ Tosa_QuantizedInt, AnyFloat]>;
//===----------------------------------------------------------------------===//
// Tensor types
@@ -101,18 +89,17 @@ def Tosa_Weight : AnyTypeOf<[Tosa_Int4, Tosa_Int8,
def Tosa_Int32Tensor : TensorOf<[Tosa_Int32]>;
def Tosa_Int32Or64Tensor : TensorOf<[Tosa_Int32Or64]>;
-def Tosa_FloatTensor : TensorOf<[Tosa_Float]>;
+def Tosa_FloatTensor : TensorOf<[AnyFloat]>;
// Either ranked or unranked tensor of TOSA supported element types.
def Tosa_Tensor : TensorOf<[Tosa_AnyNumber]>;
-def Tosa_Tensor_Plus_F64 : TensorOf<[Tosa_AnyNumber_Plus_F64]>;
// Must be ranked but no further constraints
def Tosa_RankedTensor : RankedTensorOf<[Tosa_AnyNumber]>;
// Any tensor element type allowed in Tosa ops.
def Tosa_ElementType : Type<Or<[Tosa_Int.predicate, Tosa_QuantizedInt.predicate,
- Tosa_Float.predicate]>, "tosa.dtype">;
+ AnyFloat.predicate]>, "tosa.dtype">;
class Tosa_TensorOfOrNone<list<Type> allowedTypes, string description = ""> :
AnyTypeOf<[TensorOf<allowedTypes>, NoneType], description>;
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
index 539501082fd3f..b78c372af77e6 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
@@ -506,11 +506,10 @@ LogicalResult TosaValidation::applyVariableCheck(Operation *op) {
}
bool TosaValidation::isValidElementType(Type type) {
- if ((profile == TosaProfileEnum::BaseInference) && isa<FloatType>(type)) {
- return false;
- }
- if (type.isF64()) {
- return false;
+ if (isa<FloatType>(type)) {
+ if (profile == TosaProfileEnum::BaseInference)
+ return false;
+ return type.isF32() || type.isF16() || type.isBF16();
}
if (auto intTy = dyn_cast<IntegerType>(type)) {
if (intTy.isUnsigned()) {
diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir
index 730ac41dd7a8d..cb38d4d81ca2e 100644
--- a/mlir/test/Dialect/Tosa/invalid.mlir
+++ b/mlir/test/Dialect/Tosa/invalid.mlir
@@ -20,7 +20,7 @@ func.func @test_conv2d(%arg0: tensor<*xi8>, %arg1: tensor<16x3x3x4xi8>, %arg2: t
// -----
func.func @test_conv2d(%arg0: tensor<1x29x29x4xi8>, %arg1: tensor<*xi8>, %arg2: tensor<16xi8>) -> tensor<1x27x27x16xi8> {
- // expected-error at +1 {{'tosa.conv2d' op operand #1 must be 4D tensor of 4-bit signless integer or 8-bit signless integer or Quint8 type or Qint4 type or Qint8 type or Qint16 type or Qint32 type or 32-bit float or 16-bit float or bfloat16 type values, but got 'tensor<*xi8>'}}
+ // expected-error at +1 {{'tosa.conv2d' op operand #1 must be 4D tensor of 4-bit signless integer or 8-bit signless integer or Quint8 type or Qint4 type or Qint8 type or Qint16 type or Qint32 type or floating-point values, but got 'tensor<*xi8>'}}
%0 = tosa.conv2d %arg0, %arg1, %arg2 {dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>}
: (tensor<1x29x29x4xi8>, tensor<*xi8>, tensor<16xi8>) -> tensor<1x27x27x16xi8>
return %0 : tensor<1x27x27x16xi8>
diff --git a/mlir/test/Dialect/Tosa/level_check.mlir b/mlir/test/Dialect/Tosa/level_check.mlir
index d8dd878051f18..9b652f2d0bd14 100644
--- a/mlir/test/Dialect/Tosa/level_check.mlir
+++ b/mlir/test/Dialect/Tosa/level_check.mlir
@@ -131,6 +131,14 @@ func.func @test_const_ui32(%arg0 : tensor<1xui32>) {
// -----
+func.func @test_const_f64(%arg0 : tensor<1xf64>) {
+ // expected-error at +1 {{'tosa.const' op is not profile-aligned: element type 'f64' is not legal}}
+ %0 = "tosa.const"() {value = dense<0.0> : tensor<1xf64>} : () -> tensor<1xf64>
+ return
+}
+
+// -----
+
func.func @test_avgpool2d_kernel_y(%arg0: tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32> {
// expected-error at +1 {{'tosa.avg_pool2d' op failed level check: kernel <= MAX_KERNEL}}
%0 = "tosa.avg_pool2d"(%arg0) {kernel = array<i64: 8193, 1>, pad = array<i64: 4, 4, 4, 4>, stride = array<i64: 1, 1>, acc_type = f32} :
More information about the Mlir-commits
mailing list