[Mlir-commits] [mlir] Fix TOSA FP16->INT16 CAST lowering (PR #79299)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Jan 24 06:57:03 PST 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-linalg
Author: Thomas Preud'homme (RoboTux)
<details>
<summary>Changes</summary>
Currently, casting from FP to int is implemented by clamping to the min and
max integer values in the floating-point domain and then converting to
integer. However, the max int values are often not representable in the
floating-point input type due to a lack of mantissa bits. This patch
instead uses a select driven by a comparison against max int + 1, which is
a power of two and therefore exactly representable in floating-point
(provided it fits within the type's exponent range).
---
Full diff: https://github.com/llvm/llvm-project/pull/79299.diff
2 Files Affected:
- (modified) mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp (+38-8)
- (modified) mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir (+14-12)
``````````diff
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index 647592395c8760..96de43caae7364 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -480,23 +480,53 @@ createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args,
}
if (arith::FPToSIOp::areCastCompatible(srcTy, dstTy)) {
- auto intMin = rewriter.create<arith::ConstantOp>(
+ auto intMinFP = rewriter.create<arith::ConstantOp>(
loc, rewriter.getFloatAttr(
getElementTypeOrSelf(srcTy),
APInt::getSignedMinValue(dstTy.getIntOrFloatBitWidth())
.getSExtValue()));
- auto intMax = rewriter.create<arith::ConstantOp>(
+ auto rounded = rewriter.create<math::RoundEvenOp>(loc, args[0]);
+
+ // The input floating-point type has enough mantissa bits to represent
+ // the max int value, so just clamp the input in the floating-point
+ // domain and convert to int. Note: the min value is always representable
+ // because it is a power of two (-2^(N-1)), whose stored fraction bits are
+ // all zero.
+ if (cast<FloatType>(srcTy).getFPMantissaWidth() >=
+ dstTy.getIntOrFloatBitWidth() - 1) {
+ auto intMaxFP = rewriter.create<arith::ConstantOp>(
+ loc, rewriter.getFloatAttr(
+ getElementTypeOrSelf(srcTy),
+ APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth())
+ .getSExtValue()));
+
+ auto clamped =
+ clampFloatHelper(loc, rounded, intMinFP, intMaxFP, rewriter);
+ return rewriter.create<arith::FPToSIOp>(loc, dstTy, clamped);
+ }
+
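+ // Otherwise, we rely on int max + 1 being representable because it is
+ // also a power of two (2^(N-1)). So clamp only the min value in the
+ // floating-point domain, and compare the rounded input against int max + 1
+ // to select the max int value on overflow.
+ // NOTE(review): this assumes 2^(N-1) lies within the source type's
+ // exponent range; e.g. for f16 -> i32 the constant becomes +inf — confirm
+ // that case is handled as intended.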
+ auto intMaxPlusOneFP = rewriter.create<arith::ConstantOp>(
loc, rewriter.getFloatAttr(
getElementTypeOrSelf(srcTy),
APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth())
- .getSExtValue()));
-
- auto rounded = rewriter.create<math::RoundEvenOp>(loc, args[0]);
+ .getSExtValue() +
+ 1));
- auto clamped = clampFloatHelper(loc, rounded, intMin, intMax, rewriter);
-
- return rewriter.create<arith::FPToSIOp>(loc, dstTy, clamped);
+ auto intMax = rewriter.create<arith::ConstantOp>(
+ loc, rewriter.getIntegerAttr(
+ getElementTypeOrSelf(dstTy),
+ APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth())));
+ auto minClampedFP =
+ rewriter.create<arith::MaximumFOp>(loc, rounded, intMinFP);
+ auto minClamped =
+ rewriter.create<arith::FPToSIOp>(loc, dstTy, minClampedFP);
+ auto overflow = rewriter.create<arith::CmpFOp>(
+ loc, arith::CmpFPredicate::UGE, rounded, intMaxPlusOneFP);
+ return rewriter.create<arith::SelectOp>(loc, overflow, intMax,
+ minClamped);
}
// Casting to boolean, integers need to only be checked as not-equal to
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index 1f63b7d5ca6c8b..b19f9a04bd6f3b 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -514,12 +514,14 @@ func.func @test_simple_f32(%arg0: tensor<1xf32>) -> () {
%19 = tosa.sigmoid %0 : (tensor<1xf32>) -> tensor<1xf32>
// CHECK: linalg.generic
- // CHECK: arith.constant -2.14748365E+9
- // CHECK: arith.constant 2.14748365E+9
- // CHECK: math.roundeven
- // CHECK: arith.minimumf
- // CHECK: arith.maximumf
- // CHECK: arith.fptosi
+ // CHECK: [[CSTMIN:%[a-z0-9_]+]] = arith.constant -2.14748365E+9 : f32
+ // CHECK: [[ROUND:%[a-z0-9_]+]] = math.roundeven {{%[a-z0-9_]+}} : f32
+ // CHECK: [[CSTMAXP1:%[a-z0-9_]+]] = arith.constant 2.14748365E+9 : f32
+ // CHECK: [[CSTMAX:%[a-z0-9_]+]] = arith.constant 2147483647 : i32
+ // CHECK: [[MAX:%[a-z0-9_]+]] = arith.maximumf [[ROUND]], [[CSTMIN]] : f32
+ // CHECK: [[CONV:%[a-z0-9_]+]] = arith.fptosi [[MAX]] : f32 to i32
+ // CHECK: [[CMP:%[a-z0-9_]+]] = arith.cmpf uge, [[ROUND]], [[CSTMAXP1]] : f32
+ // CHECK: arith.select [[CMP]], [[CSTMAX]], [[CONV]] : i32
%20 = tosa.cast %0 : (tensor<1xf32>) -> tensor<1xi32>
// CHECK: linalg.generic
@@ -552,12 +554,12 @@ func.func @test_simple_f16(%arg0: tensor<1xf16>) -> () {
%0 = tosa.cast %arg0 : (tensor<1xf16>) -> tensor<1xf32>
// CHECK: linalg.generic
- // CHECK: arith.constant -1.280000e+02
- // CHECK: arith.constant 1.270000e+02
- // CHECK: math.roundeven
- // CHECK: arith.minimumf
- // CHECK: arith.maximumf
- // CHECK: arith.fptosi
+ // CHECK: [[CSTMIN:%[a-z0-9_]+]] = arith.constant -1.280000e+02 : f16
+ // CHECK: [[ROUND:%[a-z0-9_]+]] = math.roundeven {{%[a-z0-9_]+}} : f16
+ // CHECK: [[CSTMAX:%[a-z0-9_]+]] = arith.constant 1.270000e+02 : f16
+ // CHECK: [[MIN:%[a-z0-9_]+]] = arith.minimumf [[ROUND]], [[CSTMAX]] : f16
+ // CHECK: [[CLAMP:%[a-z0-9_]+]] = arith.maximumf [[MIN]], [[CSTMIN]] : f16
+ // CHECK: arith.fptosi [[CLAMP]] : f16 to i8
%1 = "tosa.cast"(%arg0) : (tensor<1xf16>) -> tensor<1xi8>
return
}
``````````
</details>
https://github.com/llvm/llvm-project/pull/79299
More information about the Mlir-commits
mailing list