[Mlir-commits] [mlir] Fix TOSA FP16->INT16 CAST lowering (PR #79299)

Thomas Preud'homme llvmlistbot at llvm.org
Fri Jan 26 02:31:26 PST 2024


https://github.com/RoboTux updated https://github.com/llvm/llvm-project/pull/79299

>From 52cda74cb231971ec87248838844fba8151b466d Mon Sep 17 00:00:00 2001
From: Thomas Preud'homme <thomas.preudhomme at arm.com>
Date: Mon, 22 Jan 2024 17:06:00 +0000
Subject: [PATCH 1/2] Fix TOSA FP16->INT16 CAST lowering

Currently, a cast from FP to int is implemented by clamping on the min and
max integer values in the floating-point domain and then converting to
integer. However, the max int values are often not representable in the
floating-point input type due to lack of mantissa bits. This patch
instead uses a select acting on a compare against max int + 1, which is
representable in floating-point.
---
 .../Conversion/TosaToLinalg/TosaToLinalg.cpp  | 46 +++++++++++++++----
 .../TosaToLinalg/tosa-to-linalg.mlir          | 26 ++++++-----
 2 files changed, 52 insertions(+), 20 deletions(-)

diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index 647592395c87607..96de43caae7364f 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -480,23 +480,53 @@ createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args,
     }
 
     if (arith::FPToSIOp::areCastCompatible(srcTy, dstTy)) {
-      auto intMin = rewriter.create<arith::ConstantOp>(
+      auto intMinFP = rewriter.create<arith::ConstantOp>(
           loc, rewriter.getFloatAttr(
                    getElementTypeOrSelf(srcTy),
                    APInt::getSignedMinValue(dstTy.getIntOrFloatBitWidth())
                        .getSExtValue()));
 
-      auto intMax = rewriter.create<arith::ConstantOp>(
+      auto rounded = rewriter.create<math::RoundEvenOp>(loc, args[0]);
+
+      // The input floating-point type has enough mantissa bits to represent
+      // the max int value so just clamp the input in the floating-point
+      // domain and convert to int. Note: the min value can be represented
+      // because it consists of a mantissa with only the lsb set.
+      if (cast<FloatType>(srcTy).getFPMantissaWidth() >=
+          dstTy.getIntOrFloatBitWidth() - 1) {
+        auto intMaxFP = rewriter.create<arith::ConstantOp>(
+            loc, rewriter.getFloatAttr(
+                     getElementTypeOrSelf(srcTy),
+                     APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth())
+                         .getSExtValue()));
+
+        auto clamped =
+            clampFloatHelper(loc, rounded, intMinFP, intMaxFP, rewriter);
+        return rewriter.create<arith::FPToSIOp>(loc, dstTy, clamped);
+      }
+
+      // Otherwise, we can rely on int max + 1 being representable because it
+      // also consists of a single lsb set in the mantissa. So clamp the min
+      // value and compare against that to select the max int value if needed.
+      auto intMaxPlusOneFP = rewriter.create<arith::ConstantOp>(
           loc, rewriter.getFloatAttr(
                    getElementTypeOrSelf(srcTy),
                    APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth())
-                       .getSExtValue()));
-
-      auto rounded = rewriter.create<math::RoundEvenOp>(loc, args[0]);
+                           .getSExtValue() +
+                       1));
 
-      auto clamped = clampFloatHelper(loc, rounded, intMin, intMax, rewriter);
-
-      return rewriter.create<arith::FPToSIOp>(loc, dstTy, clamped);
+      auto intMax = rewriter.create<arith::ConstantOp>(
+          loc, rewriter.getIntegerAttr(
+                   getElementTypeOrSelf(dstTy),
+                   APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth())));
+      auto minClampedFP =
+          rewriter.create<arith::MaximumFOp>(loc, rounded, intMinFP);
+      auto minClamped =
+          rewriter.create<arith::FPToSIOp>(loc, dstTy, minClampedFP);
+      auto overflow = rewriter.create<arith::CmpFOp>(
+          loc, arith::CmpFPredicate::UGE, rounded, intMaxPlusOneFP);
+      return rewriter.create<arith::SelectOp>(loc, overflow, intMax,
+                                              minClamped);
     }
 
     // Casting to boolean, integers need to only be checked as not-equal to
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index 1f63b7d5ca6c8be..b19f9a04bd6f3b4 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -514,12 +514,14 @@ func.func @test_simple_f32(%arg0: tensor<1xf32>) -> () {
   %19 = tosa.sigmoid %0 : (tensor<1xf32>) -> tensor<1xf32>
 
   // CHECK: linalg.generic
-  // CHECK: arith.constant -2.14748365E+9
-  // CHECK: arith.constant 2.14748365E+9
-  // CHECK: math.roundeven
-  // CHECK: arith.minimumf
-  // CHECK: arith.maximumf
-  // CHECK: arith.fptosi
+  // CHECK: [[CSTMIN:%[a-z0-9_]+]] = arith.constant -2.14748365E+9 : f32
+  // CHECK: [[ROUND:%[a-z0-9_]+]] = math.roundeven {{%[a-z0-9_]+}} : f32
+  // CHECK: [[CSTMAXP1:%[a-z0-9_]+]] = arith.constant 2.14748365E+9 : f32
+  // CHECK: [[CSTMAX:%[a-z0-9_]+]] = arith.constant 2147483647 : i32
+  // CHECK: [[MAX:%[a-z0-9_]+]] = arith.maximumf [[ROUND]], [[CSTMIN]] : f32
+  // CHECK: [[CONV:%[a-z0-9_]+]] = arith.fptosi [[MAX]] : f32 to i32
+  // CHECK: [[CMP:%[a-z0-9_]+]] = arith.cmpf uge, [[ROUND]], [[CSTMAXP1]] : f32
+  // CHECK: arith.select [[CMP]], [[CSTMAX]], [[CONV]] : i32
   %20 = tosa.cast %0 : (tensor<1xf32>) -> tensor<1xi32>
 
   // CHECK: linalg.generic
@@ -552,12 +554,12 @@ func.func @test_simple_f16(%arg0: tensor<1xf16>) -> () {
   %0 = tosa.cast %arg0 : (tensor<1xf16>) -> tensor<1xf32>
 
   // CHECK: linalg.generic
-  // CHECK: arith.constant -1.280000e+02
-  // CHECK: arith.constant 1.270000e+02
-  // CHECK: math.roundeven
-  // CHECK: arith.minimumf
-  // CHECK: arith.maximumf
-  // CHECK: arith.fptosi
+  // CHECK: [[CSTMIN:%[a-z0-9_]+]] = arith.constant -1.280000e+02 : f16
+  // CHECK: [[ROUND:%[a-z0-9_]+]] = math.roundeven {{%[a-z0-9_]+}} : f16
+  // CHECK: [[CSTMAX:%[a-z0-9_]+]] = arith.constant 1.270000e+02 : f16
+  // CHECK: [[MIN:%[a-z0-9_]+]] = arith.minimumf [[ROUND]], [[CSTMAX]] : f16
+  // CHECK: [[CLAMP:%[a-z0-9_]+]] = arith.maximumf [[MIN]], [[CSTMIN]] : f16
+  // CHECK: arith.fptosi [[CLAMP]] : f16 to i8
   %1 = "tosa.cast"(%arg0) : (tensor<1xf16>) -> tensor<1xi8>
   return
 }

>From 1f0d691b78285ed04df35e8f02908c905eac5baa Mon Sep 17 00:00:00 2001
From: Thomas Preud'homme <thomas.preudhomme at arm.com>
Date: Fri, 26 Jan 2024 10:23:04 +0000
Subject: [PATCH 2/2] Address F16->I32 case and clarify comments

---
 .../Conversion/TosaToLinalg/TosaToLinalg.cpp  | 50 +++++++++++++++----
 .../TosaToLinalg/tosa-to-linalg.mlir          | 37 +++++++++-----
 2 files changed, 66 insertions(+), 21 deletions(-)

diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index 96de43caae7364f..5ffc5311b6fd83c 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -480,18 +480,50 @@ createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args,
     }
 
     if (arith::FPToSIOp::areCastCompatible(srcTy, dstTy)) {
+      auto rounded = rewriter.create<math::RoundEvenOp>(loc, args[0]);
+
+      const auto &fltSemantics = cast<FloatType>(srcTy).getFloatSemantics();
+      // The range of integer values is wider than the range of finite
+      // floating-point values, so only infinite values need to be clamped.
+      if (static_cast<int>(dstTy.getIntOrFloatBitWidth()) - 1 >
+          APFloat::semanticsMaxExponent(fltSemantics)) {
+        auto conv = rewriter.create<arith::FPToSIOp>(loc, dstTy, rounded);
+        auto posInf = rewriter.create<arith::ConstantOp>(
+            loc, rewriter.getFloatAttr(getElementTypeOrSelf(srcTy),
+                                       APFloat::getInf(fltSemantics)));
+        auto negInf = rewriter.create<arith::ConstantOp>(
+            loc, rewriter.getFloatAttr(
+                     getElementTypeOrSelf(srcTy),
+                     APFloat::getInf(fltSemantics, /*Negative=*/true)));
+        auto overflow = rewriter.create<arith::CmpFOp>(
+            loc, arith::CmpFPredicate::UEQ, rounded, posInf);
+        auto underflow = rewriter.create<arith::CmpFOp>(
+            loc, arith::CmpFPredicate::UEQ, rounded, negInf);
+        auto intMin = rewriter.create<arith::ConstantOp>(
+            loc, rewriter.getIntegerAttr(
+                     getElementTypeOrSelf(dstTy),
+                     APInt::getSignedMinValue(dstTy.getIntOrFloatBitWidth())));
+        auto intMax = rewriter.create<arith::ConstantOp>(
+            loc, rewriter.getIntegerAttr(
+                     getElementTypeOrSelf(dstTy),
+                     APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth())));
+        auto maxClamped =
+            rewriter.create<arith::SelectOp>(loc, overflow, intMax, conv);
+        return rewriter.create<arith::SelectOp>(loc, underflow, intMin,
+                                                maxClamped);
+      }
+
       auto intMinFP = rewriter.create<arith::ConstantOp>(
           loc, rewriter.getFloatAttr(
                    getElementTypeOrSelf(srcTy),
                    APInt::getSignedMinValue(dstTy.getIntOrFloatBitWidth())
                        .getSExtValue()));
 
-      auto rounded = rewriter.create<math::RoundEvenOp>(loc, args[0]);
-
       // The input floating-point type has enough mantissa bits to represent
-      // the max int value so just clamp the input in the floating-point
-      // domain and convert to int. Note: the min value can be represented
-      // because it consists of a mantissa with only the lsb set.
+      // the max int value (n-1 bits set for a n-bit integer) so just clamp the
+      // input in the floating-point domain and convert to int. Note: the min
+      // value can be represented in the mantissa because, being a power of 2,
+      // it consists of a single leading bit.
       if (cast<FloatType>(srcTy).getFPMantissaWidth() >=
           dstTy.getIntOrFloatBitWidth() - 1) {
         auto intMaxFP = rewriter.create<arith::ConstantOp>(
@@ -500,14 +532,14 @@ createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args,
                      APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth())
                          .getSExtValue()));
 
-        auto clamped =
+        Value clamped =
             clampFloatHelper(loc, rounded, intMinFP, intMaxFP, rewriter);
         return rewriter.create<arith::FPToSIOp>(loc, dstTy, clamped);
       }
 
-      // Otherwise, we can rely on int max + 1 being representable because it
-      // also consists of a single lsb set in the mantissa. So clamp the min
-      // value and compare against that to select the max int value if needed.
+      // Otherwise, we can rely on int max + 1 being representable because
+      // it's just int min with a positive sign. So clamp the min value and
+      // compare against that to select the max int value if needed.
       auto intMaxPlusOneFP = rewriter.create<arith::ConstantOp>(
           loc, rewriter.getFloatAttr(
                    getElementTypeOrSelf(srcTy),
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index b19f9a04bd6f3b4..fc22a436526a6fc 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -514,13 +514,13 @@ func.func @test_simple_f32(%arg0: tensor<1xf32>) -> () {
   %19 = tosa.sigmoid %0 : (tensor<1xf32>) -> tensor<1xf32>
 
   // CHECK: linalg.generic
-  // CHECK: [[CSTMIN:%[a-z0-9_]+]] = arith.constant -2.14748365E+9 : f32
-  // CHECK: [[ROUND:%[a-z0-9_]+]] = math.roundeven {{%[a-z0-9_]+}} : f32
-  // CHECK: [[CSTMAXP1:%[a-z0-9_]+]] = arith.constant 2.14748365E+9 : f32
-  // CHECK: [[CSTMAX:%[a-z0-9_]+]] = arith.constant 2147483647 : i32
-  // CHECK: [[MAX:%[a-z0-9_]+]] = arith.maximumf [[ROUND]], [[CSTMIN]] : f32
-  // CHECK: [[CONV:%[a-z0-9_]+]] = arith.fptosi [[MAX]] : f32 to i32
-  // CHECK: [[CMP:%[a-z0-9_]+]] = arith.cmpf uge, [[ROUND]], [[CSTMAXP1]] : f32
+  // CHECK: [[ROUND:%.+]] = math.roundeven {{%.+}} : f32
+  // CHECK: [[CSTMIN:%.+]] = arith.constant -2.14748365E+9 : f32
+  // CHECK: [[CSTMAXP1:%.+]] = arith.constant 2.14748365E+9 : f32
+  // CHECK: [[CSTMAX:%.+]] = arith.constant 2147483647 : i32
+  // CHECK: [[MAX:%.+]] = arith.maximumf [[ROUND]], [[CSTMIN]] : f32
+  // CHECK: [[CONV:%.+]] = arith.fptosi [[MAX]] : f32 to i32
+  // CHECK: [[CMP:%.+]] = arith.cmpf uge, [[ROUND]], [[CSTMAXP1]] : f32
   // CHECK: arith.select [[CMP]], [[CSTMAX]], [[CONV]] : i32
   %20 = tosa.cast %0 : (tensor<1xf32>) -> tensor<1xi32>
 
@@ -554,13 +554,26 @@ func.func @test_simple_f16(%arg0: tensor<1xf16>) -> () {
   %0 = tosa.cast %arg0 : (tensor<1xf16>) -> tensor<1xf32>
 
   // CHECK: linalg.generic
-  // CHECK: [[CSTMIN:%[a-z0-9_]+]] = arith.constant -1.280000e+02 : f16
-  // CHECK: [[ROUND:%[a-z0-9_]+]] = math.roundeven {{%[a-z0-9_]+}} : f16
-  // CHECK: [[CSTMAX:%[a-z0-9_]+]] = arith.constant 1.270000e+02 : f16
-  // CHECK: [[MIN:%[a-z0-9_]+]] = arith.minimumf [[ROUND]], [[CSTMAX]] : f16
-  // CHECK: [[CLAMP:%[a-z0-9_]+]] = arith.maximumf [[MIN]], [[CSTMIN]] : f16
+  // CHECK: [[ROUND:%.+]] = math.roundeven {{%.+}} : f16
+  // CHECK: [[CSTMIN:%.+]] = arith.constant -1.280000e+02 : f16
+  // CHECK: [[CSTMAX:%.+]] = arith.constant 1.270000e+02 : f16
+  // CHECK: [[MIN:%.+]] = arith.minimumf [[ROUND]], [[CSTMAX]] : f16
+  // CHECK: [[CLAMP:%.+]] = arith.maximumf [[MIN]], [[CSTMIN]] : f16
   // CHECK: arith.fptosi [[CLAMP]] : f16 to i8
   %1 = "tosa.cast"(%arg0) : (tensor<1xf16>) -> tensor<1xi8>
+
+  // CHECK: linalg.generic
+  // CHECK: [[ROUND:%.+]] = math.roundeven {{%[a-z0-9_]+}} : f16
+  // CHECK: [[CONV:%.+]] = arith.fptosi [[ROUND]] : f16 to i32
+  // CHECK: [[POSINF:%.+]] = arith.constant 0x7C00 : f16
+  // CHECK: [[NEGINF:%.+]] = arith.constant 0xFC00 : f16
+  // CHECK: [[OVERFLOW:%.+]] = arith.cmpf ueq, [[ROUND]], [[POSINF]] : f16
+  // CHECK: [[UNDERFLOW:%.+]] = arith.cmpf ueq, [[ROUND]], [[NEGINF]] : f16
+  // CHECK: [[MININT:%.+]] = arith.constant -2147483648 : i32
+  // CHECK: [[MAXINT:%.+]] = arith.constant 2147483647 : i32
+  // CHECK: [[CLAMPPOSINF:%.+]] = arith.select [[OVERFLOW]], [[MAXINT]], [[CONV]] : i32
+  // CHECK: arith.select [[UNDERFLOW]], [[MININT]], [[CLAMPPOSINF]] : i32
+  %2 = "tosa.cast"(%arg0) : (tensor<1xf16>) -> tensor<1xi32>
   return
 }
 



More information about the Mlir-commits mailing list