[Mlir-commits] [mlir] [mlir][Tosa] fix fp16/bf16 support for Clamp min/max attributes (PR #69192)

llvmlistbot at llvm.org
Mon Oct 16 05:03:50 PDT 2023


https://github.com/fabrizio-indirli created https://github.com/llvm/llvm-project/pull/69192

In the TOSA MLIR dialect, fix the definition of the Clamp op so that the min_fp and max_fp attributes accept the fp16 and bf16 datatypes in addition to f32.
Add related test cases in Tosa/ops.mlir.

From 15b73988c4f8381de5056441a1e75476b5f7251c Mon Sep 17 00:00:00 2001
From: Fabrizio Indirli <Fabrizio.Indirli at arm.com>
Date: Mon, 16 Oct 2023 10:07:19 +0100
Subject: [PATCH] [mlir][Tosa] fix fp16/bf16 support for Clamp min/max
 attributes

In the TOSA MLIR dialect, fix the definition of the Clamp op to
accept the fp16 & bf16 datatypes for the min_fp and max_fp attributes.
Add related test cases in Tosa/ops.mlir.

Signed-off-by: Fabrizio Indirli <Fabrizio.Indirli at arm.com>
---
 mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td       |  4 ++--
 mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td | 11 +++++++++++
 mlir/test/Dialect/Tosa/ops.mlir                    | 14 ++++++++++++++
 3 files changed, 27 insertions(+), 2 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index a80111aedfe0b59..59220bffa27ca7e 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -380,8 +380,8 @@ def Tosa_ClampOp : Tosa_ElementwiseOp<"clamp"> {
     Tosa_Tensor:$input,
     I64Attr:$min_int,
     I64Attr:$max_int,
-    F32Attr:$min_fp,
-    F32Attr:$max_fp
+    Tosa_FloatAttr:$min_fp,
+    Tosa_FloatAttr:$max_fp
   );
 
   let results = (outs
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
index e39f4662e791918..34dd30f37bf4d62 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
@@ -197,6 +197,17 @@ def Tosa_IntArrayAttrUpto2 : ConfinedAttr<DenseI64ArrayAttr, [DenseArrayMaxCt<2>
 def Tosa_IntArrayAttrUpto4 : ConfinedAttr<DenseI64ArrayAttr, [DenseArrayMaxCt<4>]>;
 def Tosa_IntArrayAttrUpto5 : ConfinedAttr<DenseI64ArrayAttr, [DenseArrayMaxCt<5>]>;
 
+def Tosa_FloatAttr : Attr<And<[CPred<"::llvm::isa<::mlir::FloatAttr>($_self)">,
+                                Or<[CPred<"::llvm::cast<::mlir::FloatAttr>($_self).getType().isBF16()">,
+                                    CPred<"::llvm::cast<::mlir::FloatAttr>($_self).getType().isF16()">,
+                                    CPred<"::llvm::cast<::mlir::FloatAttr>($_self).getType().isF32()">,
+                                    ]>
+                                  ]>,
+                     "BF16/F16/F32 float attribute"> {
+  let storageType = [{ ::mlir::FloatAttr }];
+  let returnType = [{ ::mlir::APFloat }];
+}
+
 //===----------------------------------------------------------------------===//
 // Iterable attributes.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Tosa/ops.mlir b/mlir/test/Dialect/Tosa/ops.mlir
index e62bea515d06baa..e18943a0fd613d5 100644
--- a/mlir/test/Dialect/Tosa/ops.mlir
+++ b/mlir/test/Dialect/Tosa/ops.mlir
@@ -124,6 +124,20 @@ func.func @test_clamp(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
   return %0 : tensor<13x21x3xf32>
 }
 
+// -----
+// CHECK-LABEL: clamp_f16
+func.func @test_clamp_f16(%arg0: tensor<13x21x3xf16>) -> tensor<13x21x3xf16> {
+  %0 = tosa.clamp %arg0 {min_fp = 0.0 : f16, max_fp = 1.0: f16, min_int = 0 : i64, max_int = 1 : i64} : (tensor<13x21x3xf16>) -> tensor<13x21x3xf16>
+  return %0 : tensor<13x21x3xf16>
+}
+
+// -----
+// CHECK-LABEL: clamp_bf16
+func.func @test_clamp_bf16(%arg0: tensor<13x21x3xbf16>) -> tensor<13x21x3xbf16> {
+  %0 = tosa.clamp %arg0 {min_fp = 0.0 : bf16, max_fp = 1.0: bf16, min_int = 0 : i64, max_int = 1 : i64} : (tensor<13x21x3xbf16>) -> tensor<13x21x3xbf16>
+  return %0 : tensor<13x21x3xbf16>
+}
+
 // -----
 // CHECK-LABEL: sigmoid
 func.func @test_sigmoid(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {


