[Mlir-commits] [mlir] [mlir] [TOSA] Allow any floating point type (PR #91745)

Matthias Gehre llvmlistbot at llvm.org
Fri May 10 06:53:39 PDT 2024


https://github.com/mgehre-amd created https://github.com/llvm/llvm-project/pull/91745

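After #86509 allowed all integer types in TOSA ops, this PR allows TOSA ops on all floating-point types.
This makes it possible to experiment with `f64` and 8-bit float types when spec conformance is not required; the TOSA validation pass still accepts only `f32`, `f16`, and `bf16` (and rejects all floating-point types under the base-inference profile).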

From 0c68ac7bcab973e9e5b4d265379857f13ce49a35 Mon Sep 17 00:00:00 2001
From: Matthias Gehre <matthias.gehre at amd.com>
Date: Mon, 6 May 2024 12:06:48 +0200
Subject: [PATCH] [mlir] [TOSA] Allow any floating point type

---
 mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td  |  6 +++---
 .../mlir/Dialect/Tosa/IR/TosaTypesBase.td     | 21 ++++---------------
 .../Tosa/Transforms/TosaValidation.cpp        |  9 ++++----
 mlir/test/Dialect/Tosa/invalid.mlir           |  2 +-
 mlir/test/Dialect/Tosa/level_check.mlir       |  8 +++++++
 5 files changed, 20 insertions(+), 26 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index 97a36c49d01b3..7871b46724a03 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -1857,11 +1857,11 @@ def Tosa_CastOp: Tosa_Op<"cast", [Pure,
   }];
 
   let arguments = (ins
-    Tosa_Tensor_Plus_F64:$input
+    Tosa_Tensor:$input
   );
 
   let results = (outs
-    Tosa_Tensor_Plus_F64:$output
+    Tosa_Tensor:$output
   );
 
   let assemblyFormat = "operands attr-dict `:` functional-type(operands, results)";
@@ -1944,7 +1944,7 @@ def Tosa_ConstOp : Tosa_Op<"const", [ConstantLike, Pure,
   );
 
   let results = (outs
-    TensorOf<[AnyTypeOf<[Tosa_AnyNumber_Plus_F64]>]>:$output
+    TensorOf<[AnyTypeOf<[Tosa_AnyNumber]>]>:$output
   );
 
   let hasFolder = 1;
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
index 3687891fe4b7c..14fc9c7a6730c 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
@@ -71,28 +71,16 @@ def Tosa_QuantizedInt	: AnyTypeOf<[ Tosa_QuantizedType<"uint8", [8], 0>,
                                      Tosa_QuantizedType<"int16", [16, 0], 1>,
                                      Tosa_QuantizedType<"int32", [32, 0], 1>]>;
 
-//===----------------------------------------------------------------------===//
-// Floating-point types.
-//===----------------------------------------------------------------------===//
-def Tosa_Float : AnyTypeOf<[
-                            F32,
-			    F16,
-			    BF16]>;
-
 //===----------------------------------------------------------------------===//
 // Multi-category types.
 //===----------------------------------------------------------------------===//
-def Tosa_AnyNumber : AnyTypeOf<[Tosa_Int, Tosa_QuantizedInt, Tosa_Float],
+def Tosa_AnyNumber : AnyTypeOf<[Tosa_Int, Tosa_QuantizedInt, AnyFloat],
                                "number">;
 
-// Add F64 type support just for tosa::CastOp and tosa::ConstOp
-def Tosa_AnyNumber_Plus_F64 : AnyTypeOf<[Tosa_Int, Tosa_QuantizedInt, Tosa_Float, F64],
-                               "number_plus_f64">;
-
 // For weight tensors from tosa::Conv2DOp, tosa::Conv3DOp,
 // tosa::DepthwiseConv2DOp, tosa::TransposeConv2DOp, tosa::FullyConnectedOp
 def Tosa_Weight : AnyTypeOf<[Tosa_Int4, Tosa_Int8,
-                             Tosa_QuantizedInt, Tosa_Float]>;
+                             Tosa_QuantizedInt, AnyFloat]>;
 
 //===----------------------------------------------------------------------===//
 // Tensor types
@@ -101,18 +89,17 @@ def Tosa_Weight : AnyTypeOf<[Tosa_Int4, Tosa_Int8,
 def Tosa_Int32Tensor : TensorOf<[Tosa_Int32]>;
 def Tosa_Int32Or64Tensor : TensorOf<[Tosa_Int32Or64]>;
 
-def Tosa_FloatTensor : TensorOf<[Tosa_Float]>;
+def Tosa_FloatTensor : TensorOf<[AnyFloat]>;
 
 // Either ranked or unranked tensor of TOSA supported element types.
 def Tosa_Tensor : TensorOf<[Tosa_AnyNumber]>;
-def Tosa_Tensor_Plus_F64 : TensorOf<[Tosa_AnyNumber_Plus_F64]>;
 
 // Must be ranked but no further constraints
 def Tosa_RankedTensor : RankedTensorOf<[Tosa_AnyNumber]>;
 
 // Any tensor element type allowed in Tosa ops.
 def Tosa_ElementType : Type<Or<[Tosa_Int.predicate, Tosa_QuantizedInt.predicate,
-                                Tosa_Float.predicate]>, "tosa.dtype">;
+                                AnyFloat.predicate]>, "tosa.dtype">;
 
 class Tosa_TensorOfOrNone<list<Type> allowedTypes, string description = ""> :
   AnyTypeOf<[TensorOf<allowedTypes>, NoneType], description>;
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
index 539501082fd3f..b78c372af77e6 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
@@ -506,11 +506,10 @@ LogicalResult TosaValidation::applyVariableCheck(Operation *op) {
 }
 
 bool TosaValidation::isValidElementType(Type type) {
-  if ((profile == TosaProfileEnum::BaseInference) && isa<FloatType>(type)) {
-    return false;
-  }
-  if (type.isF64()) {
-    return false;
+  if (isa<FloatType>(type)) {
+    if (profile == TosaProfileEnum::BaseInference)
+      return false;
+    return type.isF32() || type.isF16() || type.isBF16();
   }
   if (auto intTy = dyn_cast<IntegerType>(type)) {
     if (intTy.isUnsigned()) {
diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir
index 730ac41dd7a8d..cb38d4d81ca2e 100644
--- a/mlir/test/Dialect/Tosa/invalid.mlir
+++ b/mlir/test/Dialect/Tosa/invalid.mlir
@@ -20,7 +20,7 @@ func.func @test_conv2d(%arg0: tensor<*xi8>, %arg1: tensor<16x3x3x4xi8>, %arg2: t
 // -----
 
 func.func @test_conv2d(%arg0: tensor<1x29x29x4xi8>, %arg1: tensor<*xi8>, %arg2: tensor<16xi8>) -> tensor<1x27x27x16xi8> {
-  // expected-error@+1 {{'tosa.conv2d' op operand #1 must be 4D tensor of 4-bit signless integer or 8-bit signless integer or Quint8 type or Qint4 type or Qint8 type or Qint16 type or Qint32 type or 32-bit float or 16-bit float or bfloat16 type values, but got 'tensor<*xi8>'}}
+  // expected-error@+1 {{'tosa.conv2d' op operand #1 must be 4D tensor of 4-bit signless integer or 8-bit signless integer or Quint8 type or Qint4 type or Qint8 type or Qint16 type or Qint32 type or floating-point values, but got 'tensor<*xi8>'}}
   %0 = tosa.conv2d %arg0, %arg1, %arg2 {dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>}
            : (tensor<1x29x29x4xi8>, tensor<*xi8>, tensor<16xi8>) -> tensor<1x27x27x16xi8>
   return %0 : tensor<1x27x27x16xi8>
diff --git a/mlir/test/Dialect/Tosa/level_check.mlir b/mlir/test/Dialect/Tosa/level_check.mlir
index d8dd878051f18..9b652f2d0bd14 100644
--- a/mlir/test/Dialect/Tosa/level_check.mlir
+++ b/mlir/test/Dialect/Tosa/level_check.mlir
@@ -131,6 +131,14 @@ func.func @test_const_ui32(%arg0 : tensor<1xui32>) {
 
 // -----
 
+func.func @test_const_f64(%arg0 : tensor<1xf64>) {
+  // expected-error@+1 {{'tosa.const' op is not profile-aligned: element type 'f64' is not legal}}
+  %0 = "tosa.const"() {value = dense<0.0> : tensor<1xf64>} : () -> tensor<1xf64>
+  return
+}
+
+// -----
+
 func.func @test_avgpool2d_kernel_y(%arg0: tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32> {
   // expected-error@+1 {{'tosa.avg_pool2d' op failed level check: kernel <= MAX_KERNEL}}
   %0 = "tosa.avg_pool2d"(%arg0) {kernel = array<i64: 8193, 1>, pad = array<i64: 4, 4, 4, 4>, stride = array<i64: 1, 1>, acc_type = f32} :



More information about the Mlir-commits mailing list