[Mlir-commits] [mlir] 6e8e91a - [MLIR][TOSA] Fix converting tosa.clamp and tosa.relu to linalg
Krzysztof Drewniak
llvmlistbot at llvm.org
Mon Jul 11 10:18:53 PDT 2022
Author: jungpark-mlir
Date: 2022-07-11T17:18:47Z
New Revision: 6e8e91a7b63c51f487ddfbe2d6b2372ea1310faf
URL: https://github.com/llvm/llvm-project/commit/6e8e91a7b63c51f487ddfbe2d6b2372ea1310faf
DIFF: https://github.com/llvm/llvm-project/commit/6e8e91a7b63c51f487ddfbe2d6b2372ea1310faf.diff
LOG: [MLIR][TOSA] Fix converting tosa.clamp and tosa.relu to linalg
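The TOSA-to-Linalg conversion crashes when the input tensor has a float element type other than fp32. This happens because tosa.clamp and tosa.reluN store their min/max bounds as fp32 attributes, and those attributes were used directly as the values of arith.constant ops. The resulting constant's value attribute stays f32 while its result type is the tensor's element type (e.g. f16).
This commit fixes the crash by converting the float bounds to the input tensor's element type before creating the constants.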
Reviewed By: eric-k256
Differential Revision: https://reviews.llvm.org/D128630
Added:
Modified:
mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index 5711ffe003f47..d14d5d7d136ac 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -369,10 +369,17 @@ createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args,
// tosa::ClampOp
if (isa<tosa::ClampOp>(op) && elementTy.isa<FloatType>()) {
- auto min = rewriter.create<arith::ConstantOp>(loc, elementTy,
- op->getAttr("min_fp"));
- auto max = rewriter.create<arith::ConstantOp>(loc, elementTy,
- op->getAttr("max_fp"));
+ bool losesInfo = false;
+ APFloat min_apf = op->getAttr("min_fp").cast<FloatAttr>().getValue();
+ APFloat max_apf = op->getAttr("max_fp").cast<FloatAttr>().getValue();
+ min_apf.convert(elementTy.cast<FloatType>().getFloatSemantics(),
+ APFloat::rmNearestTiesToEven, &losesInfo);
+ max_apf.convert(elementTy.cast<FloatType>().getFloatSemantics(),
+ APFloat::rmNearestTiesToEven, &losesInfo);
+ auto min = rewriter.create<arith::ConstantOp>(
+ loc, elementTy, rewriter.getFloatAttr(elementTy, min_apf));
+ auto max = rewriter.create<arith::ConstantOp>(
+ loc, elementTy, rewriter.getFloatAttr(elementTy, max_apf));
return clampHelper<arith::CmpFOp>(loc, args[0], min, max,
arith::CmpFPredicate::OLT, rewriter);
}
@@ -410,8 +417,12 @@ createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args,
if (isa<tosa::ReluNOp>(op) && elementTy.isa<FloatType>()) {
auto zero =
rewriter.create<arith::ConstantOp>(loc, FloatAttr::get(elementTy, 0));
- auto n = rewriter.create<arith::ConstantOp>(loc, elementTy,
- op->getAttr("max_fp"));
+ bool losesInfo = false;
+ APFloat max_apf = op->getAttr("max_fp").cast<FloatAttr>().getValue();
+ max_apf.convert(elementTy.cast<FloatType>().getFloatSemantics(),
+ APFloat::rmNearestTiesToEven, &losesInfo);
+ auto n = rewriter.create<arith::ConstantOp>(
+ loc, elementTy, rewriter.getFloatAttr(elementTy, max_apf));
return clampHelper<arith::CmpFOp>(loc, args[0], zero, n,
arith::CmpFPredicate::OLT, rewriter);
}
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index 357a77c5ab780..193c49c4ee87b 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -473,6 +473,22 @@ func.func @test_i8(%arg0: tensor<1xi8>) -> () {
// -----
+// CHECK-LABEL: @test_clamp_f16
+func.func @test_clamp_f16(%arg0: tensor<1xf16>) -> () {
+ // CHECK: linalg.generic
+ // CHECK-DAG: %[[C0:.+]] = arith.constant 0.0
+ // CHECK-DAG: %[[C6:.+]] = arith.constant 6.0
+ // CHECK-DAG: %[[CMP1:.+]] = arith.cmpf olt, %arg1, %[[C0]]
+ // CHECK-DAG: %[[SEL1:.+]] = arith.select %[[CMP1]], %[[C0]]
+ // CHECK-DAG: %[[CMP2:.+]] = arith.cmpf olt, %[[C6]], %arg1
+ // CHECK: %[[SEL2:.+]] = arith.select %[[CMP2]], %[[C6]], %[[SEL1]]
+ %0 = "tosa.clamp"(%arg0) {min_int = 0 : i64, max_int = 0 : i64, min_fp = 0.0 : f32, max_fp = 6.0 : f32} : (tensor<1xf16>) -> tensor<1xf16>
+
+ return
+}
+
+// -----
+
// CHECK-LABEL: @test_bool
func.func @test_bool(%arg0: tensor<1xi1>, %arg1: tensor<1xi1>) -> () {
// CHECK: linalg.generic
More information about the Mlir-commits
mailing list