[Mlir-commits] [mlir] dde7b80 - [mlir][Tosa] fix fp16/bf16 support for Clamp min/max attributes (#69192)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Oct 20 01:44:04 PDT 2023


Author: fabrizio-indirli
Date: 2023-10-20T09:44:01+01:00
New Revision: dde7b80ed071dfb874b91e15f2ba413af4d9a6b5

URL: https://github.com/llvm/llvm-project/commit/dde7b80ed071dfb874b91e15f2ba413af4d9a6b5
DIFF: https://github.com/llvm/llvm-project/commit/dde7b80ed071dfb874b91e15f2ba413af4d9a6b5.diff

LOG: [mlir][Tosa] fix fp16/bf16 support for Clamp min/max attributes (#69192)

In the TOSA MLIR dialect, fix the definition of the Clamp op so that the
min_fp and max_fp attributes accept fp16 and bf16 datatypes.
Add a ClampOp verifier that checks the min/max attribute types are
compatible with the input/output element types.
Add related test cases in Tosa/ops.mlir.

Signed-off-by: Fabrizio Indirli <Fabrizio.Indirli at arm.com>

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
    mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
    mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
    mlir/test/Dialect/Tosa/ops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index a80111aedfe0b59..5cc97469d14c314 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -380,8 +380,8 @@ def Tosa_ClampOp : Tosa_ElementwiseOp<"clamp"> {
     Tosa_Tensor:$input,
     I64Attr:$min_int,
     I64Attr:$max_int,
-    F32Attr:$min_fp,
-    F32Attr:$max_fp
+    Tosa_FloatAttr:$min_fp,
+    Tosa_FloatAttr:$max_fp
   );
 
   let results = (outs
@@ -389,6 +389,7 @@ def Tosa_ClampOp : Tosa_ElementwiseOp<"clamp"> {
   );
 
   let hasCanonicalizer = 1;
+  let hasVerifier = 1;
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
index e39f4662e791918..c55ddaafdda76e2 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
@@ -197,6 +197,12 @@ def Tosa_IntArrayAttrUpto2 : ConfinedAttr<DenseI64ArrayAttr, [DenseArrayMaxCt<2>
 def Tosa_IntArrayAttrUpto4 : ConfinedAttr<DenseI64ArrayAttr, [DenseArrayMaxCt<4>]>;
 def Tosa_IntArrayAttrUpto5 : ConfinedAttr<DenseI64ArrayAttr, [DenseArrayMaxCt<5>]>;
 
+def Tosa_FloatAttr : Attr<CPred<"::llvm::isa<::mlir::FloatAttr>($_self)">,
+                          "arbitrary float attribute"> { // Accepts any FloatAttr, whatever its float type (e.g. f16, bf16, f32).
+  let storageType = [{ ::mlir::FloatAttr }]; // Stored as the FloatAttr itself.
+  let returnType = [{ ::mlir::APFloat }];    // Generated value accessors return ::mlir::APFloat.
+}
+
 //===----------------------------------------------------------------------===//
 // Iterable attributes.
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index ff34183f9a030a8..2e9339c0ca2edc5 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -309,6 +309,39 @@ LogicalResult tosa::AvgPool2dOp::verify() {
   return emitOpError("input/output element types are incompatible.");
 }
 
+/// Verifies tosa.clamp. The input and output element types must match; for
+/// floating-point element types, min_fp and max_fp must share one type that
+/// is either the element type itself or a strictly wider float type.
+LogicalResult tosa::ClampOp::verify() {
+  mlir::Type inputETy =
+      llvm::cast<ShapedType>(getInput().getType()).getElementType();
+  mlir::Type outputETy =
+      llvm::cast<ShapedType>(getOutput().getType()).getElementType();
+
+  if (inputETy != outputETy)
+    return emitOpError("input/output element types are incompatible.");
+
+  // min_fp/max_fp only constrain floating-point element types. Integer and
+  // any other non-float element types are accepted here without inspecting
+  // the attributes, which also avoids calling getIntOrFloatBitWidth() on an
+  // element type that is neither an integer nor a float.
+  auto inputFloatTy = llvm::dyn_cast<FloatType>(inputETy);
+  if (!inputFloatTy)
+    return success();
+
+  // Both bounds must share a single type: either the input's own float type
+  // or a float type strictly wider than it (e.g. f32 bounds on f16 data).
+  mlir::Type maxFpType = getMaxFpAttr().getType();
+  mlir::Type minFpType = getMinFpAttr().getType();
+  if (maxFpType != minFpType ||
+      (maxFpType != inputETy &&
+       maxFpType.getIntOrFloatBitWidth() <= inputFloatTy.getWidth()))
+    return emitOpError("min/max attributes types are incompatible with "
+                       "input/output element types.");
+
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // TOSA Operator Quantization Builders.
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Dialect/Tosa/ops.mlir b/mlir/test/Dialect/Tosa/ops.mlir
index 8ce8fb73f29a504..064c9160480fdcb 100644
--- a/mlir/test/Dialect/Tosa/ops.mlir
+++ b/mlir/test/Dialect/Tosa/ops.mlir
@@ -138,6 +138,20 @@ func.func @test_clamp(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
   return %0 : tensor<13x21x3xf32>
 }
 
+// -----
+// CHECK-LABEL: clamp_f16
+func.func @test_clamp_f16(%arg0: tensor<13x21x3xf16>) -> tensor<13x21x3xf16> { // f16 data with f16-typed min_fp/max_fp bounds.
+  %0 = tosa.clamp %arg0 {min_fp = 0.0 : f16, max_fp = 1.0 : f16, min_int = 0 : i64, max_int = 1 : i64} : (tensor<13x21x3xf16>) -> tensor<13x21x3xf16>
+  return %0 : tensor<13x21x3xf16>
+}
+
+// -----
+// CHECK-LABEL: clamp_bf16
+func.func @test_clamp_bf16(%arg0: tensor<13x21x3xbf16>) -> tensor<13x21x3xbf16> { // bf16 data with bf16-typed min_fp/max_fp bounds.
+  %0 = tosa.clamp %arg0 {min_fp = 0.0 : bf16, max_fp = 1.0 : bf16, min_int = 0 : i64, max_int = 1 : i64} : (tensor<13x21x3xbf16>) -> tensor<13x21x3xbf16>
+  return %0 : tensor<13x21x3xbf16>
+}
+
 // -----
 // CHECK-LABEL: sigmoid
 func.func @test_sigmoid(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {


        


More information about the Mlir-commits mailing list