[Mlir-commits] [mlir] [mlir][Tosa] fix fp16/bf16 support for Clamp min/max attributes (PR #69192)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Oct 19 07:55:08 PDT 2023
https://github.com/fabrizio-indirli updated https://github.com/llvm/llvm-project/pull/69192
>From 0ef8ec4966b20d341f54af462d929d8bd249c32e Mon Sep 17 00:00:00 2001
From: Fabrizio Indirli <Fabrizio.Indirli at arm.com>
Date: Mon, 16 Oct 2023 10:07:19 +0100
Subject: [PATCH] [mlir][Tosa] fix fp16/bf16 support for Clamp min/max
attributes
In the TOSA MLIR dialect, fix the definition of the Clamp op so that
its min_fp and max_fp attributes accept fp16 and bf16 datatypes.
Add a ClampOp verifier to check attribute type compatibility.
Add related test cases in Tosa/ops.mlir.
Signed-off-by: Fabrizio Indirli <Fabrizio.Indirli at arm.com>
---
mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td | 6 +++--
.../mlir/Dialect/Tosa/IR/TosaTypesBase.td | 6 +++++
mlir/lib/Dialect/Tosa/IR/TosaOps.cpp | 26 +++++++++++++++++++
mlir/test/Dialect/Tosa/ops.mlir | 14 ++++++++++
4 files changed, 50 insertions(+), 2 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index a80111aedfe0b59..e924b995548d9a4 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -380,8 +380,8 @@ def Tosa_ClampOp : Tosa_ElementwiseOp<"clamp"> {
Tosa_Tensor:$input,
I64Attr:$min_int,
I64Attr:$max_int,
- F32Attr:$min_fp,
- F32Attr:$max_fp
+ Tosa_FloatAttr:$min_fp,
+ Tosa_FloatAttr:$max_fp
);
let results = (outs
@@ -389,6 +389,8 @@ def Tosa_ClampOp : Tosa_ElementwiseOp<"clamp"> {
);
let hasCanonicalizer = 1;
+
+ let hasVerifier = 1;
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
index e39f4662e791918..c55ddaafdda76e2 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
@@ -197,6 +197,12 @@ def Tosa_IntArrayAttrUpto2 : ConfinedAttr<DenseI64ArrayAttr, [DenseArrayMaxCt<2>
def Tosa_IntArrayAttrUpto4 : ConfinedAttr<DenseI64ArrayAttr, [DenseArrayMaxCt<4>]>;
def Tosa_IntArrayAttrUpto5 : ConfinedAttr<DenseI64ArrayAttr, [DenseArrayMaxCt<5>]>;
+def Tosa_FloatAttr : Attr<CPred<"::llvm::isa<::mlir::FloatAttr>($_self)">,
+                          "arbitrary float attribute"> { // FloatAttr of any float type
+  let storageType = "::mlir::FloatAttr";
+  let returnType = "::mlir::APFloat";
+}
+
//===----------------------------------------------------------------------===//
// Iterable attributes.
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 6db04fe38bcd356..8a779bcb6f315d4 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -266,6 +266,32 @@ LogicalResult tosa::AvgPool2dOp::verify() {
return emitOpError("input/output element types are incompatible.");
}
+/// Check that input/output element types match and, for float inputs, that
+/// min_fp/max_fp share one type that is the input type or a wider float.
+LogicalResult tosa::ClampOp::verify() {
+  // Quantized element types are checked through their storage type.
+  auto storageType = [](Value v) -> mlir::Type {
+    mlir::Type ety = llvm::cast<ShapedType>(v.getType()).getElementType();
+    if (auto qty = llvm::dyn_cast<mlir::quant::QuantizedType>(ety))
+      return qty.getStorageType();
+    return ety;
+  };
+  mlir::Type inputDT = storageType(getInput());
+  mlir::Type maxFpDT = getMaxFpAttr().getType();
+  mlir::Type minFpDT = getMinFpAttr().getType();
+  // Compare whole types: equal TypeIDs would also accept e.g. i8 vs i32.
+  if (inputDT != storageType(getOutput()))
+    return emitOpError("input/output element types are incompatible.");
+  // Only float inputs reach getIntOrFloatBitWidth(), which asserts otherwise.
+  if (llvm::isa<FloatType>(inputDT) &&
+      (maxFpDT != minFpDT ||
+       (maxFpDT != inputDT &&
+        maxFpDT.getIntOrFloatBitWidth() <= inputDT.getIntOrFloatBitWidth())))
+    return emitOpError("min/max attributes types are incompatible with "
+                       "input/output element types.");
+  return success();
+}
+
//===----------------------------------------------------------------------===//
// TOSA Operator Quantization Builders.
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Tosa/ops.mlir b/mlir/test/Dialect/Tosa/ops.mlir
index e62bea515d06baa..e18943a0fd613d5 100644
--- a/mlir/test/Dialect/Tosa/ops.mlir
+++ b/mlir/test/Dialect/Tosa/ops.mlir
@@ -124,6 +124,20 @@ func.func @test_clamp(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
return %0 : tensor<13x21x3xf32>
}
+// -----
+// CHECK-LABEL: clamp_f16
+func.func @test_clamp_f16(%arg0: tensor<13x21x3xf16>) -> tensor<13x21x3xf16> {
+  %0 = tosa.clamp %arg0 {max_fp = 1.0 : f16, max_int = 1 : i64, min_fp = 0.0 : f16, min_int = 0 : i64} : (tensor<13x21x3xf16>) -> tensor<13x21x3xf16>
+  return %0 : tensor<13x21x3xf16>
+}
+
+// -----
+// CHECK-LABEL: clamp_bf16
+func.func @test_clamp_bf16(%arg0: tensor<13x21x3xbf16>) -> tensor<13x21x3xbf16> {
+  %0 = tosa.clamp %arg0 {max_fp = 1.0 : bf16, max_int = 1 : i64, min_fp = 0.0 : bf16, min_int = 0 : i64} : (tensor<13x21x3xbf16>) -> tensor<13x21x3xbf16>
+  return %0 : tensor<13x21x3xbf16>
+}
+
// -----
// CHECK-LABEL: sigmoid
func.func @test_sigmoid(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
More information about the Mlir-commits
mailing list