[Mlir-commits] [mlir] 3f8aafd - [mlir][tosa] Fix tosa.cast semantics to perform rounding/clipping

Rob Suderman llvmlistbot at llvm.org
Wed May 12 22:00:18 PDT 2021


Author: Rob Suderman
Date: 2021-05-12T21:53:53-07:00
New Revision: 3f8aafd7902722cc2039c7ef3d6747f8d49f81a6

URL: https://github.com/llvm/llvm-project/commit/3f8aafd7902722cc2039c7ef3d6747f8d49f81a6
DIFF: https://github.com/llvm/llvm-project/commit/3f8aafd7902722cc2039c7ef3d6747f8d49f81a6.diff

LOG: [mlir][tosa] Fix tosa.cast semantics to perform rounding/clipping

Casting to integers requires rounding (for floating-point sources) and
clipping to the min/max values of the destination range. Added this behavior
and updated tests appropriately.

Reviewed By: sjarus, silvas

Differential Revision: https://reviews.llvm.org/D102375

Added: 
    

Modified: 
    mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
    mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index ba5316ce4167..66e174760aee 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -491,9 +491,34 @@ createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args,
                                            args.front(), zero);
     }
 
-    if (mlir::FPToSIOp::areCastCompatible(srcTy, dstTy))
-      return rewriter.create<mlir::FPToSIOp>(loc, resultTypes, args,
-                                             mlir::None);
+    if (mlir::FPToSIOp::areCastCompatible(srcTy, dstTy)) {
+      auto zero =
+          rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(0.0f));
+      auto half =
+          rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(0.5f));
+
+      auto intMin = rewriter.create<ConstantOp>(
+          loc, rewriter.getF32FloatAttr(
+                   APInt::getSignedMinValue(dstTy.getIntOrFloatBitWidth())
+                       .getSExtValue()));
+
+      auto intMax = rewriter.create<ConstantOp>(
+          loc, rewriter.getF32FloatAttr(
+                   APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth())
+                       .getSExtValue()));
+
+      auto added = rewriter.create<AddFOp>(loc, args[0], half);
+      auto subbed = rewriter.create<SubFOp>(loc, args[0], half);
+      auto negative =
+          rewriter.create<mlir::CmpFOp>(loc, CmpFPredicate::OLT, args[0], zero);
+      auto rounded =
+          rewriter.create<mlir::SelectOp>(loc, negative, subbed, added);
+
+      auto clamped = clampHelper<mlir::CmpFOp>(loc, rounded, intMin, intMax,
+                                               CmpFPredicate::OLT, rewriter);
+
+      return rewriter.create<mlir::FPToSIOp>(loc, dstTy, clamped);
+    }
 
     // Casting to boolean, integers need to only be checked as not-equal to
     // zero.
@@ -508,9 +533,23 @@ createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args,
       return rewriter.create<mlir::SignExtendIOp>(loc, resultTypes, args,
                                                   mlir::None);
 
-    if (srcTy.isa<IntegerType>() && dstTy.isa<IntegerType>() && !bitExtend)
-      return rewriter.create<mlir::TruncateIOp>(loc, resultTypes, args,
-                                                mlir::None);
+    if (srcTy.isa<IntegerType>() && dstTy.isa<IntegerType>() && !bitExtend) {
+      auto intMin = rewriter.create<ConstantIntOp>(
+          loc,
+          APInt::getSignedMinValue(dstTy.getIntOrFloatBitWidth())
+              .getSExtValue(),
+          srcTy.getIntOrFloatBitWidth());
+
+      auto intMax = rewriter.create<ConstantIntOp>(
+          loc,
+          APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth())
+              .getSExtValue(),
+          srcTy.getIntOrFloatBitWidth());
+
+      auto clamped = clampHelper<mlir::CmpIOp>(loc, args[0], intMin, intMax,
+                                               CmpIPredicate::slt, rewriter);
+      return rewriter.create<mlir::TruncateIOp>(loc, dstTy, clamped);
+    }
   }
 
   (void)rewriter.notifyMatchFailure(

diff  --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index f0bb4c5cc883..46841a15bd63 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -213,6 +213,18 @@ func @test_simple_f32(%arg0: tensor<1xf32>) -> () {
   %20 = "tosa.sigmoid"(%0) : (tensor<1xf32>) -> tensor<1xf32>
 
   // CHECK: linalg.generic
+  // CHECK: constant 0.000000e+00
+  // CHECK: constant 5.000000e-01
+  // CHECK: constant -2.14748365E+9
+  // CHECK: constant 2.14748365E+9
+  // CHECK: addf
+  // CHECK: subf
+  // CHECK: cmpf olt
+  // CHECK: select
+  // CHECK: cmpf olt
+  // CHECK: select
+  // CHECK: cmpf olt
+  // CHECK: select
   // CHECK: fptosi
   %21 = "tosa.cast"(%0) : (tensor<1xf32>) -> tensor<1xi32>
 
@@ -358,6 +370,12 @@ func @test_simple_i32(%arg0: tensor<1xi32>) -> () {
   %18 = "tosa.reluN"(%0) {max_int = 5 : i64, max_fp = 5.0 : f32} : (tensor<1xi32>) -> tensor<1xi32>
 
   // CHECK: linalg.generic
+  // CHECK: constant -32768
+  // CHECK: constant 32767
+  // CHECK: cmpi slt
+  // CHECK: select
+  // CHECK: cmpi slt
+  // CHECK: select
   // CHECK: trunci
   %19 = "tosa.cast"(%0) : (tensor<1xi32>) -> tensor<1xi16>
 


        


More information about the Mlir-commits mailing list