[Mlir-commits] [mlir] Fix TOSA FP16->INT16 CAST lowering (PR #79299)
Thomas Preud'homme
llvmlistbot at llvm.org
Tue Jan 30 03:27:17 PST 2024
https://github.com/RoboTux updated https://github.com/llvm/llvm-project/pull/79299
>From 52cda74cb231971ec87248838844fba8151b466d Mon Sep 17 00:00:00 2001
From: Thomas Preud'homme <thomas.preudhomme at arm.com>
Date: Mon, 22 Jan 2024 17:06:00 +0000
Subject: [PATCH 1/3] Fix TOSA FP16->INT16 CAST lowering
Currently, casting from FP to int is implemented by clamping to the min and max
integer values in the floating-point domain and then converting to
integer. However, the max int values are often not representable in the
floating-point input type due to lack of mantissa bits. This patch
instead uses a select acting on a compare against max int + 1, which is
representable in floating-point.
---
.../Conversion/TosaToLinalg/TosaToLinalg.cpp | 46 +++++++++++++++----
.../TosaToLinalg/tosa-to-linalg.mlir | 26 ++++++-----
2 files changed, 52 insertions(+), 20 deletions(-)
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index 647592395c87..96de43caae73 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -480,23 +480,53 @@ createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args,
}
if (arith::FPToSIOp::areCastCompatible(srcTy, dstTy)) {
- auto intMin = rewriter.create<arith::ConstantOp>(
+ auto intMinFP = rewriter.create<arith::ConstantOp>(
loc, rewriter.getFloatAttr(
getElementTypeOrSelf(srcTy),
APInt::getSignedMinValue(dstTy.getIntOrFloatBitWidth())
.getSExtValue()));
- auto intMax = rewriter.create<arith::ConstantOp>(
+ auto rounded = rewriter.create<math::RoundEvenOp>(loc, args[0]);
+
+ // The input floating-point type has enough mantissa bits to represent
+ // the max int value so just clamp the input in the floating-point
+ // domain and convert to int. Note: the min value can be represented
+ // because it consists of a mantissa with only the lsb set.
+ if (cast<FloatType>(srcTy).getFPMantissaWidth() >=
+ dstTy.getIntOrFloatBitWidth() - 1) {
+ auto intMaxFP = rewriter.create<arith::ConstantOp>(
+ loc, rewriter.getFloatAttr(
+ getElementTypeOrSelf(srcTy),
+ APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth())
+ .getSExtValue()));
+
+ auto clamped =
+ clampFloatHelper(loc, rounded, intMinFP, intMaxFP, rewriter);
+ return rewriter.create<arith::FPToSIOp>(loc, dstTy, clamped);
+ }
+
+ // Otherwise, we can rely on int max + 1 being representable because it
+ // also consists of a single lsb set in the mantissa. So clamp the min
+ // value and compare against that to select the max int value if needed.
+ auto intMaxPlusOneFP = rewriter.create<arith::ConstantOp>(
loc, rewriter.getFloatAttr(
getElementTypeOrSelf(srcTy),
APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth())
- .getSExtValue()));
-
- auto rounded = rewriter.create<math::RoundEvenOp>(loc, args[0]);
+ .getSExtValue() +
+ 1));
- auto clamped = clampFloatHelper(loc, rounded, intMin, intMax, rewriter);
-
- return rewriter.create<arith::FPToSIOp>(loc, dstTy, clamped);
+ auto intMax = rewriter.create<arith::ConstantOp>(
+ loc, rewriter.getIntegerAttr(
+ getElementTypeOrSelf(dstTy),
+ APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth())));
+ auto minClampedFP =
+ rewriter.create<arith::MaximumFOp>(loc, rounded, intMinFP);
+ auto minClamped =
+ rewriter.create<arith::FPToSIOp>(loc, dstTy, minClampedFP);
+ auto overflow = rewriter.create<arith::CmpFOp>(
+ loc, arith::CmpFPredicate::UGE, rounded, intMaxPlusOneFP);
+ return rewriter.create<arith::SelectOp>(loc, overflow, intMax,
+ minClamped);
}
// Casting to boolean, integers need to only be checked as not-equal to
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index 1f63b7d5ca6c..b19f9a04bd6f 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -514,12 +514,14 @@ func.func @test_simple_f32(%arg0: tensor<1xf32>) -> () {
%19 = tosa.sigmoid %0 : (tensor<1xf32>) -> tensor<1xf32>
// CHECK: linalg.generic
- // CHECK: arith.constant -2.14748365E+9
- // CHECK: arith.constant 2.14748365E+9
- // CHECK: math.roundeven
- // CHECK: arith.minimumf
- // CHECK: arith.maximumf
- // CHECK: arith.fptosi
+ // CHECK: [[CSTMIN:%[a-z0-9_]+]] = arith.constant -2.14748365E+9 : f32
+ // CHECK: [[ROUND:%[a-z0-9_]+]] = math.roundeven {{%[a-z0-9_]+}} : f32
+ // CHECK: [[CSTMAXP1:%[a-z0-9_]+]] = arith.constant 2.14748365E+9 : f32
+ // CHECK: [[CSTMAX:%[a-z0-9_]+]] = arith.constant 2147483647 : i32
+ // CHECK: [[MAX:%[a-z0-9_]+]] = arith.maximumf [[ROUND]], [[CSTMIN]] : f32
+ // CHECK: [[CONV:%[a-z0-9_]+]] = arith.fptosi [[MAX]] : f32 to i32
+ // CHECK: [[CMP:%[a-z0-9_]+]] = arith.cmpf uge, [[ROUND]], [[CSTMAXP1]] : f32
+ // CHECK: arith.select [[CMP]], [[CSTMAX]], [[CONV]] : i32
%20 = tosa.cast %0 : (tensor<1xf32>) -> tensor<1xi32>
// CHECK: linalg.generic
@@ -552,12 +554,12 @@ func.func @test_simple_f16(%arg0: tensor<1xf16>) -> () {
%0 = tosa.cast %arg0 : (tensor<1xf16>) -> tensor<1xf32>
// CHECK: linalg.generic
- // CHECK: arith.constant -1.280000e+02
- // CHECK: arith.constant 1.270000e+02
- // CHECK: math.roundeven
- // CHECK: arith.minimumf
- // CHECK: arith.maximumf
- // CHECK: arith.fptosi
+ // CHECK: [[CSTMIN:%[a-z0-9_]+]] = arith.constant -1.280000e+02 : f16
+ // CHECK: [[ROUND:%[a-z0-9_]+]] = math.roundeven {{%[a-z0-9_]+}} : f16
+ // CHECK: [[CSTMAX:%[a-z0-9_]+]] = arith.constant 1.270000e+02 : f16
+ // CHECK: [[MIN:%[a-z0-9_]+]] = arith.minimumf [[ROUND]], [[CSTMAX]] : f16
+ // CHECK: [[CLAMP:%[a-z0-9_]+]] = arith.maximumf [[MIN]], [[CSTMIN]] : f16
+ // CHECK: arith.fptosi [[CLAMP]] : f16 to i8
%1 = "tosa.cast"(%arg0) : (tensor<1xf16>) -> tensor<1xi8>
return
}
>From 1f0d691b78285ed04df35e8f02908c905eac5baa Mon Sep 17 00:00:00 2001
From: Thomas Preud'homme <thomas.preudhomme at arm.com>
Date: Fri, 26 Jan 2024 10:23:04 +0000
Subject: [PATCH 2/3] Address F16->I32 case and clarify comments
---
.../Conversion/TosaToLinalg/TosaToLinalg.cpp | 50 +++++++++++++++----
.../TosaToLinalg/tosa-to-linalg.mlir | 37 +++++++++-----
2 files changed, 66 insertions(+), 21 deletions(-)
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index 96de43caae73..5ffc5311b6fd 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -480,18 +480,50 @@ createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args,
}
if (arith::FPToSIOp::areCastCompatible(srcTy, dstTy)) {
+ auto rounded = rewriter.create<math::RoundEvenOp>(loc, args[0]);
+
+ const auto &fltSemantics = cast<FloatType>(srcTy).getFloatSemantics();
+ // The range of integer values is wider than floating-point integral
+ // values so we only need to clamp infinites values.
+ if (static_cast<int>(dstTy.getIntOrFloatBitWidth()) - 1 >
+ APFloat::semanticsMaxExponent(fltSemantics)) {
+ auto conv = rewriter.create<arith::FPToSIOp>(loc, dstTy, rounded);
+ auto posInf = rewriter.create<arith::ConstantOp>(
+ loc, rewriter.getFloatAttr(getElementTypeOrSelf(srcTy),
+ APFloat::getInf(fltSemantics)));
+ auto negInf = rewriter.create<arith::ConstantOp>(
+ loc, rewriter.getFloatAttr(
+ getElementTypeOrSelf(srcTy),
+ APFloat::getInf(fltSemantics, /*Negative=*/true)));
+ auto overflow = rewriter.create<arith::CmpFOp>(
+ loc, arith::CmpFPredicate::UEQ, rounded, posInf);
+ auto underflow = rewriter.create<arith::CmpFOp>(
+ loc, arith::CmpFPredicate::UEQ, rounded, negInf);
+ auto intMin = rewriter.create<arith::ConstantOp>(
+ loc, rewriter.getIntegerAttr(
+ getElementTypeOrSelf(dstTy),
+ APInt::getSignedMinValue(dstTy.getIntOrFloatBitWidth())));
+ auto intMax = rewriter.create<arith::ConstantOp>(
+ loc, rewriter.getIntegerAttr(
+ getElementTypeOrSelf(dstTy),
+ APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth())));
+ auto maxClamped =
+ rewriter.create<arith::SelectOp>(loc, overflow, intMax, conv);
+ return rewriter.create<arith::SelectOp>(loc, underflow, intMin,
+ maxClamped);
+ }
+
auto intMinFP = rewriter.create<arith::ConstantOp>(
loc, rewriter.getFloatAttr(
getElementTypeOrSelf(srcTy),
APInt::getSignedMinValue(dstTy.getIntOrFloatBitWidth())
.getSExtValue()));
- auto rounded = rewriter.create<math::RoundEvenOp>(loc, args[0]);
-
// The input floating-point type has enough mantissa bits to represent
- // the max int value so just clamp the input in the floating-point
- // domain and convert to int. Note: the min value can be represented
- // because it consists of a mantissa with only the lsb set.
+ // the max int value (n-1 bits set for a n-bit integer) so just clamp the
+ // input in the floating-point domain and convert to int. Note: the min
+ // value can be represented in the mantissa because, being a power of 2,
+ // it consists of a single leading bit.
if (cast<FloatType>(srcTy).getFPMantissaWidth() >=
dstTy.getIntOrFloatBitWidth() - 1) {
auto intMaxFP = rewriter.create<arith::ConstantOp>(
@@ -500,14 +532,14 @@ createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args,
APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth())
.getSExtValue()));
- auto clamped =
+ Value clamped =
clampFloatHelper(loc, rounded, intMinFP, intMaxFP, rewriter);
return rewriter.create<arith::FPToSIOp>(loc, dstTy, clamped);
}
- // Otherwise, we can rely on int max + 1 being representable because it
- // also consists of a single lsb set in the mantissa. So clamp the min
- // value and compare against that to select the max int value if needed.
+ // Otherwise, we can rely on int max + 1 being representable because
+ // it's just int min with a positive sign. So clamp the min value and
+ // compare against that to select the max int value if needed.
auto intMaxPlusOneFP = rewriter.create<arith::ConstantOp>(
loc, rewriter.getFloatAttr(
getElementTypeOrSelf(srcTy),
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index b19f9a04bd6f..fc22a436526a 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -514,13 +514,13 @@ func.func @test_simple_f32(%arg0: tensor<1xf32>) -> () {
%19 = tosa.sigmoid %0 : (tensor<1xf32>) -> tensor<1xf32>
// CHECK: linalg.generic
- // CHECK: [[CSTMIN:%[a-z0-9_]+]] = arith.constant -2.14748365E+9 : f32
- // CHECK: [[ROUND:%[a-z0-9_]+]] = math.roundeven {{%[a-z0-9_]+}} : f32
- // CHECK: [[CSTMAXP1:%[a-z0-9_]+]] = arith.constant 2.14748365E+9 : f32
- // CHECK: [[CSTMAX:%[a-z0-9_]+]] = arith.constant 2147483647 : i32
- // CHECK: [[MAX:%[a-z0-9_]+]] = arith.maximumf [[ROUND]], [[CSTMIN]] : f32
- // CHECK: [[CONV:%[a-z0-9_]+]] = arith.fptosi [[MAX]] : f32 to i32
- // CHECK: [[CMP:%[a-z0-9_]+]] = arith.cmpf uge, [[ROUND]], [[CSTMAXP1]] : f32
+ // CHECK: [[ROUND:%.+]] = math.roundeven {{%.+}} : f32
+ // CHECK: [[CSTMIN:%.+]] = arith.constant -2.14748365E+9 : f32
+ // CHECK: [[CSTMAXP1:%.+]] = arith.constant 2.14748365E+9 : f32
+ // CHECK: [[CSTMAX:%.+]] = arith.constant 2147483647 : i32
+ // CHECK: [[MAX:%.+]] = arith.maximumf [[ROUND]], [[CSTMIN]] : f32
+ // CHECK: [[CONV:%.+]] = arith.fptosi [[MAX]] : f32 to i32
+ // CHECK: [[CMP:%.+]] = arith.cmpf uge, [[ROUND]], [[CSTMAXP1]] : f32
// CHECK: arith.select [[CMP]], [[CSTMAX]], [[CONV]] : i32
%20 = tosa.cast %0 : (tensor<1xf32>) -> tensor<1xi32>
@@ -554,13 +554,26 @@ func.func @test_simple_f16(%arg0: tensor<1xf16>) -> () {
%0 = tosa.cast %arg0 : (tensor<1xf16>) -> tensor<1xf32>
// CHECK: linalg.generic
- // CHECK: [[CSTMIN:%[a-z0-9_]+]] = arith.constant -1.280000e+02 : f16
- // CHECK: [[ROUND:%[a-z0-9_]+]] = math.roundeven {{%[a-z0-9_]+}} : f16
- // CHECK: [[CSTMAX:%[a-z0-9_]+]] = arith.constant 1.270000e+02 : f16
- // CHECK: [[MIN:%[a-z0-9_]+]] = arith.minimumf [[ROUND]], [[CSTMAX]] : f16
- // CHECK: [[CLAMP:%[a-z0-9_]+]] = arith.maximumf [[MIN]], [[CSTMIN]] : f16
+ // CHECK: [[ROUND:%.+]] = math.roundeven {{%.+}} : f16
+ // CHECK: [[CSTMIN:%.+]] = arith.constant -1.280000e+02 : f16
+ // CHECK: [[CSTMAX:%.+]] = arith.constant 1.270000e+02 : f16
+ // CHECK: [[MIN:%.+]] = arith.minimumf [[ROUND]], [[CSTMAX]] : f16
+ // CHECK: [[CLAMP:%.+]] = arith.maximumf [[MIN]], [[CSTMIN]] : f16
// CHECK: arith.fptosi [[CLAMP]] : f16 to i8
%1 = "tosa.cast"(%arg0) : (tensor<1xf16>) -> tensor<1xi8>
+
+ // CHECK: linalg.generic
+ // CHECK: [[ROUND:%.+]] = math.roundeven {{%[a-z0-9_]+}} : f16
+ // CHECK: [[CONV:%.+]] = arith.fptosi [[ROUND]] : f16 to i32
+ // CHECK: [[POSINF:%.+]] = arith.constant 0x7C00 : f16
+ // CHECK: [[NEGINF:%.+]] = arith.constant 0xFC00 : f16
+ // CHECK: [[OVERFLOW:%.+]] = arith.cmpf ueq, [[ROUND]], [[POSINF]] : f16
+ // CHECK: [[UNDERFLOW:%.+]] = arith.cmpf ueq, [[ROUND]], [[NEGINF]] : f16
+ // CHECK: [[MININT:%.+]] = arith.constant -2147483648 : i32
+ // CHECK: [[MAXINT:%.+]] = arith.constant 2147483647 : i32
+ // CHECK: [[CLAMPPOSINF:%.+]] = arith.select [[OVERFLOW]], [[MAXINT]], [[CONV]] : i32
+ // CHECK: arith.select [[UNDERFLOW]], [[MININT]], [[CLAMPPOSINF]] : i32
+ %2 = "tosa.cast"(%arg0) : (tensor<1xf16>) -> tensor<1xi32>
return
}
>From e9162d6d6ca874ead7e981fdf9102eb916c8efd3 Mon Sep 17 00:00:00 2001
From: Thomas Preud'homme <thomas.preudhomme at arm.com>
Date: Tue, 30 Jan 2024 11:21:42 +0000
Subject: [PATCH 3/3] Do not write "if" comments as statements
---
.../Conversion/TosaToLinalg/TosaToLinalg.cpp | 23 +++++++++++--------
1 file changed, 13 insertions(+), 10 deletions(-)
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index 5ffc5311b6fd..1eb5678b4175 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -483,10 +483,12 @@ createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args,
auto rounded = rewriter.create<math::RoundEvenOp>(loc, args[0]);
const auto &fltSemantics = cast<FloatType>(srcTy).getFloatSemantics();
- // The range of integer values is wider than floating-point integral
- // values so we only need to clamp infinites values.
+ // Check whether neither int min nor int max can be represented in the
+ // input floating-point type due to too short exponent range.
if (static_cast<int>(dstTy.getIntOrFloatBitWidth()) - 1 >
APFloat::semanticsMaxExponent(fltSemantics)) {
+ // Use cmp + select to replace infinities with int min / int max. Other
+ // integral values can be represented in the integer space.
auto conv = rewriter.create<arith::FPToSIOp>(loc, dstTy, rounded);
auto posInf = rewriter.create<arith::ConstantOp>(
loc, rewriter.getFloatAttr(getElementTypeOrSelf(srcTy),
@@ -519,13 +521,13 @@ createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args,
APInt::getSignedMinValue(dstTy.getIntOrFloatBitWidth())
.getSExtValue()));
- // The input floating-point type has enough mantissa bits to represent
- // the max int value (n-1 bits set for a n-bit integer) so just clamp the
- // input in the floating-point domain and convert to int. Note: the min
- // value can be represented in the mantissa because, being a power of 2,
- // it consists of a single leading bit.
+ // Check whether the mantissa has enough bits to represent int max.
if (cast<FloatType>(srcTy).getFPMantissaWidth() >=
dstTy.getIntOrFloatBitWidth() - 1) {
+ // Int min can also be represented since it is a power of two and thus
+ // consists of a single leading bit. Therefore we can clamp the input
+ // in the floating-point domain.
+
auto intMaxFP = rewriter.create<arith::ConstantOp>(
loc, rewriter.getFloatAttr(
getElementTypeOrSelf(srcTy),
@@ -537,9 +539,10 @@ createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args,
return rewriter.create<arith::FPToSIOp>(loc, dstTy, clamped);
}
- // Otherwise, we can rely on int max + 1 being representable because
- // it's just int min with a positive sign. So clamp the min value and
- // compare against that to select the max int value if needed.
+ // Due to earlier check we know exponent range is big enough to represent
+ // int min. We can therefore rely on int max + 1 being representable as
+ // well because it's just int min with a positive sign. So clamp the min
+ // value and compare against that to select the max int value if needed.
auto intMaxPlusOneFP = rewriter.create<arith::ConstantOp>(
loc, rewriter.getFloatAttr(
getElementTypeOrSelf(srcTy),
More information about the Mlir-commits
mailing list