[Mlir-commits] [mlir] b23e518 - Fix TOSA FP16->INT16 CAST lowering (#79299)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Jan 30 08:06:12 PST 2024
Author: Thomas Preud'homme
Date: 2024-01-30T16:06:08Z
New Revision: b23e518ce0df5b0835aba245cda50379bd896374
URL: https://github.com/llvm/llvm-project/commit/b23e518ce0df5b0835aba245cda50379bd896374
DIFF: https://github.com/llvm/llvm-project/commit/b23e518ce0df5b0835aba245cda50379bd896374.diff
LOG: Fix TOSA FP16->INT16 CAST lowering (#79299)
Currently, a cast from FP to int is implemented by clamping to the min
and max integer values in the floating-point domain and then converting
to integer. However, the max int value is often not representable in the
floating-point input type due to a lack of mantissa bits.
This patch instead uses a select acting on a compare against max int + 1,
which is representable in floating-point. It also adds a special lowering
for cases where the integer range is wider than the floating-point range,
in order to clamp the infinite values.
Added:
Modified:
mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index 647592395c876..1eb5678b41755 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -480,23 +480,88 @@ createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args,
}
if (arith::FPToSIOp::areCastCompatible(srcTy, dstTy)) {
- auto intMin = rewriter.create<arith::ConstantOp>(
+ auto rounded = rewriter.create<math::RoundEvenOp>(loc, args[0]);
+
+ const auto &fltSemantics = cast<FloatType>(srcTy).getFloatSemantics();
+ // Check whether neither int min nor int max can be represented in the
+ // input floating-point type due to too short exponent range.
+ if (static_cast<int>(dstTy.getIntOrFloatBitWidth()) - 1 >
+ APFloat::semanticsMaxExponent(fltSemantics)) {
+ // Use cmp + select to replace infinites by int min / int max. Other
+ // integral values can be represented in the integer space.
+ auto conv = rewriter.create<arith::FPToSIOp>(loc, dstTy, rounded);
+ auto posInf = rewriter.create<arith::ConstantOp>(
+ loc, rewriter.getFloatAttr(getElementTypeOrSelf(srcTy),
+ APFloat::getInf(fltSemantics)));
+ auto negInf = rewriter.create<arith::ConstantOp>(
+ loc, rewriter.getFloatAttr(
+ getElementTypeOrSelf(srcTy),
+ APFloat::getInf(fltSemantics, /*Negative=*/true)));
+ auto overflow = rewriter.create<arith::CmpFOp>(
+ loc, arith::CmpFPredicate::UEQ, rounded, posInf);
+ auto underflow = rewriter.create<arith::CmpFOp>(
+ loc, arith::CmpFPredicate::UEQ, rounded, negInf);
+ auto intMin = rewriter.create<arith::ConstantOp>(
+ loc, rewriter.getIntegerAttr(
+ getElementTypeOrSelf(dstTy),
+ APInt::getSignedMinValue(dstTy.getIntOrFloatBitWidth())));
+ auto intMax = rewriter.create<arith::ConstantOp>(
+ loc, rewriter.getIntegerAttr(
+ getElementTypeOrSelf(dstTy),
+ APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth())));
+ auto maxClamped =
+ rewriter.create<arith::SelectOp>(loc, overflow, intMax, conv);
+ return rewriter.create<arith::SelectOp>(loc, underflow, intMin,
+ maxClamped);
+ }
+
+ auto intMinFP = rewriter.create<arith::ConstantOp>(
loc, rewriter.getFloatAttr(
getElementTypeOrSelf(srcTy),
APInt::getSignedMinValue(dstTy.getIntOrFloatBitWidth())
.getSExtValue()));
- auto intMax = rewriter.create<arith::ConstantOp>(
+ // Check whether the mantissa has enough bits to represent int max.
+ if (cast<FloatType>(srcTy).getFPMantissaWidth() >=
+ dstTy.getIntOrFloatBitWidth() - 1) {
+ // Int min can also be represented since it is a power of two and thus
+ // consists of a single leading bit. Therefore we can clamp the input
+ // in the floating-point domain.
+
+ auto intMaxFP = rewriter.create<arith::ConstantOp>(
+ loc, rewriter.getFloatAttr(
+ getElementTypeOrSelf(srcTy),
+ APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth())
+ .getSExtValue()));
+
+ Value clamped =
+ clampFloatHelper(loc, rounded, intMinFP, intMaxFP, rewriter);
+ return rewriter.create<arith::FPToSIOp>(loc, dstTy, clamped);
+ }
+
+ // Due to earlier check we know exponant range is big enough to represent
+ // int min. We can therefore rely on int max + 1 being representable as
+ // well because it's just int min with a positive sign. So clamp the min
+ // value and compare against that to select the max int value if needed.
+ auto intMaxPlusOneFP = rewriter.create<arith::ConstantOp>(
loc, rewriter.getFloatAttr(
getElementTypeOrSelf(srcTy),
APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth())
- .getSExtValue()));
-
- auto rounded = rewriter.create<math::RoundEvenOp>(loc, args[0]);
-
- auto clamped = clampFloatHelper(loc, rounded, intMin, intMax, rewriter);
+ .getSExtValue() +
+ 1));
- return rewriter.create<arith::FPToSIOp>(loc, dstTy, clamped);
+ auto intMax = rewriter.create<arith::ConstantOp>(
+ loc, rewriter.getIntegerAttr(
+ getElementTypeOrSelf(dstTy),
+ APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth())));
+ auto minClampedFP =
+ rewriter.create<arith::MaximumFOp>(loc, rounded, intMinFP);
+ auto minClamped =
+ rewriter.create<arith::FPToSIOp>(loc, dstTy, minClampedFP);
+ auto overflow = rewriter.create<arith::CmpFOp>(
+ loc, arith::CmpFPredicate::UGE, rounded, intMaxPlusOneFP);
+ return rewriter.create<arith::SelectOp>(loc, overflow, intMax,
+ minClamped);
}
// Casting to boolean, integers need to only be checked as not-equal to
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index 1f63b7d5ca6c8..fc22a436526a6 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -514,12 +514,14 @@ func.func @test_simple_f32(%arg0: tensor<1xf32>) -> () {
%19 = tosa.sigmoid %0 : (tensor<1xf32>) -> tensor<1xf32>
// CHECK: linalg.generic
- // CHECK: arith.constant -2.14748365E+9
- // CHECK: arith.constant 2.14748365E+9
- // CHECK: math.roundeven
- // CHECK: arith.minimumf
- // CHECK: arith.maximumf
- // CHECK: arith.fptosi
+ // CHECK: [[ROUND:%.+]] = math.roundeven {{%.+}} : f32
+ // CHECK: [[CSTMIN:%.+]] = arith.constant -2.14748365E+9 : f32
+ // CHECK: [[CSTMAXP1:%.+]] = arith.constant 2.14748365E+9 : f32
+ // CHECK: [[CSTMAX:%.+]] = arith.constant 2147483647 : i32
+ // CHECK: [[MAX:%.+]] = arith.maximumf [[ROUND]], [[CSTMIN]] : f32
+ // CHECK: [[CONV:%.+]] = arith.fptosi [[MAX]] : f32 to i32
+ // CHECK: [[CMP:%.+]] = arith.cmpf uge, [[ROUND]], [[CSTMAXP1]] : f32
+ // CHECK: arith.select [[CMP]], [[CSTMAX]], [[CONV]] : i32
%20 = tosa.cast %0 : (tensor<1xf32>) -> tensor<1xi32>
// CHECK: linalg.generic
@@ -552,13 +554,26 @@ func.func @test_simple_f16(%arg0: tensor<1xf16>) -> () {
%0 = tosa.cast %arg0 : (tensor<1xf16>) -> tensor<1xf32>
// CHECK: linalg.generic
- // CHECK: arith.constant -1.280000e+02
- // CHECK: arith.constant 1.270000e+02
- // CHECK: math.roundeven
- // CHECK: arith.minimumf
- // CHECK: arith.maximumf
- // CHECK: arith.fptosi
+ // CHECK: [[ROUND:%.+]] = math.roundeven {{%.+}} : f16
+ // CHECK: [[CSTMIN:%.+]] = arith.constant -1.280000e+02 : f16
+ // CHECK: [[CSTMAX:%.+]] = arith.constant 1.270000e+02 : f16
+ // CHECK: [[MIN:%.+]] = arith.minimumf [[ROUND]], [[CSTMAX]] : f16
+ // CHECK: [[CLAMP:%.+]] = arith.maximumf [[MIN]], [[CSTMIN]] : f16
+ // CHECK: arith.fptosi [[CLAMP]] : f16 to i8
%1 = "tosa.cast"(%arg0) : (tensor<1xf16>) -> tensor<1xi8>
+
+ // CHECK: linalg.generic
+ // CHECK: [[ROUND:%.+]] = math.roundeven {{%[a-z0-9_]+}} : f16
+ // CHECK: [[CONV:%.+]] = arith.fptosi [[ROUND]] : f16 to i32
+ // CHECK: [[POSINF:%.+]] = arith.constant 0x7C00 : f16
+ // CHECK: [[NEGINF:%.+]] = arith.constant 0xFC00 : f16
+ // CHECK: [[OVERFLOW:%.+]] = arith.cmpf ueq, [[ROUND]], [[POSINF]] : f16
+ // CHECK: [[UNDERFLOW:%.+]] = arith.cmpf ueq, [[ROUND]], [[NEGINF]] : f16
+ // CHECK: [[MININT:%.+]] = arith.constant -2147483648 : i32
+ // CHECK: [[MAXINT:%.+]] = arith.constant 2147483647 : i32
+ // CHECK: [[CLAMPPOSINF:%.+]] = arith.select [[OVERFLOW]], [[MAXINT]], [[CONV]] : i32
+ // CHECK: arith.select [[UNDERFLOW]], [[MININT]], [[CLAMPPOSINF]] : i32
+ %2 = "tosa.cast"(%arg0) : (tensor<1xf16>) -> tensor<1xi32>
return
}
More information about the Mlir-commits
mailing list