[Mlir-commits] [mlir] 1c5fd15 - [mlir][Tosa] Allow non-fp32 tosa.cast to integers

Krzysztof Drewniak llvmlistbot at llvm.org
Tue Aug 29 07:11:14 PDT 2023


Author: Krzysztof Drewniak
Date: 2023-08-29T14:11:08Z
New Revision: 1c5fd1534cfd95fb4ce6356aa7719c7dbe37bee9

URL: https://github.com/llvm/llvm-project/commit/1c5fd1534cfd95fb4ce6356aa7719c7dbe37bee9
DIFF: https://github.com/llvm/llvm-project/commit/1c5fd1534cfd95fb4ce6356aa7719c7dbe37bee9.diff

LOG: [mlir][Tosa] Allow non-fp32 tosa.cast to integers

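Fix the lowering of tosa.cast from floating-point to integer types so that
the clamp bounds (the signed minimum and maximum of the destination integer
type) are created as float attributes of the source element type, rather
than always as f32 attributes.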

This is motivated by the need to cast fp16 to i8, which we have
encountered in certain quantized models.
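
The bounds involved in this fix are the signed minimum and maximum of the destination integer width. The C++ sketch below recomputes them with plain 64-bit arithmetic. `signedMinValue` and `signedMaxValue` are hypothetical helpers that are assumed to agree with `APInt::getSignedMinValue(w).getSExtValue()` and `APInt::getSignedMaxValue(w).getSExtValue()` for widths 1 to 64. The C++ `float` only stands in for "the source element type" and is not the MLIR type itself.

```cpp
#include <cstdint>

// Hypothetical stand-ins for APInt::getSignedMinValue(w).getSExtValue() and
// APInt::getSignedMaxValue(w).getSExtValue(), valid for 1 <= w <= 64.
constexpr std::int64_t signedMinValue(unsigned w) {
  return w == 64 ? INT64_MIN : -(std::int64_t{1} << (w - 1));
}

constexpr std::int64_t signedMaxValue(unsigned w) {
  return w == 64 ? INT64_MAX : (std::int64_t{1} << (w - 1)) - 1;
}

// For an i8 destination the bounds are -128 and 127. After the fix they are
// materialized in the source float type (float stands in for it here)
// rather than always as an f32 attribute.
static_assert(signedMinValue(8) == -128, "i8 minimum");
static_assert(signedMaxValue(8) == 127, "i8 maximum");
static_assert(static_cast<float>(signedMinValue(8)) == -128.0f,
              "i8 minimum is exact in float");
```

For i8 these give -128 and 127, which correspond to the -1.280000e+02 and 1.270000e+02 constants in the updated test.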

Reviewed By: eric-k256, jpienaar

Differential Revision: https://reviews.llvm.org/D158738

Added: 
    

Modified: 
    mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
    mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index bfd08ad389610a..cf560d49b5094f 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -481,12 +481,14 @@ createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args,
 
     if (arith::FPToSIOp::areCastCompatible(srcTy, dstTy)) {
       auto intMin = rewriter.create<arith::ConstantOp>(
-          loc, rewriter.getF32FloatAttr(
+          loc, rewriter.getFloatAttr(
+                   getElementTypeOrSelf(srcTy),
                    APInt::getSignedMinValue(dstTy.getIntOrFloatBitWidth())
                        .getSExtValue()));
 
       auto intMax = rewriter.create<arith::ConstantOp>(
-          loc, rewriter.getF32FloatAttr(
+          loc, rewriter.getFloatAttr(
+                   getElementTypeOrSelf(srcTy),
                    APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth())
                        .getSExtValue()));
 

diff  --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index 8029eae47ba4a7..c12b801e39ad6f 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -551,6 +551,14 @@ func.func @test_simple_f16(%arg0: tensor<1xf16>) -> () {
   // CHECK: arith.extf
   %0 = tosa.cast %arg0 : (tensor<1xf16>) -> tensor<1xf32>
 
+  // CHECK: linalg.generic
+  // CHECK: arith.constant -1.280000e+02
+  // CHECK: arith.constant 1.270000e+02
+  // CHECK: math.roundeven
+  // CHECK: arith.minf
+  // CHECK: arith.maxf
+  // CHECK: arith.fptosi
+  %1 = "tosa.cast"(%arg0) : (tensor<1xf16>) -> tensor<1xi8>
   return
 }
 
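The CHECK lines in the test describe the scalar body of the generated linalg.generic: math.roundeven, a clamp with arith.minf and arith.maxf, then arith.fptosi. Below is a rough scalar C++ model of that sequence, offered as a sketch under the following assumptions. `castFloatToInt` is a hypothetical name. The template parameter `Float` stands in for the source element type, and `float`/`double` are used because standard C++ has no portable f16. `std::nearbyint` under the default round-to-nearest mode stands in for math.roundeven. NaN inputs are not modeled, because `std::fmin`/`std::fmax` return the non-NaN operand while arith.minf/arith.maxf propagate NaN, and converting NaN to an integer is undefined in C++.

```cpp
#include <cassert>
#include <cmath>
#include <cstdint>
#include <limits>

// Hypothetical scalar model of the lowered float-to-int tosa.cast body.
// Float plays the role of the source element type; Int is the destination.
// Assumes Int's min/max are exactly representable in Float and x is not NaN.
template <typename Int, typename Float>
Int castFloatToInt(Float x) {
  // Clamp bounds built in the source type (the point of the fix), not f32.
  const Float intMin = static_cast<Float>(std::numeric_limits<Int>::min());
  const Float intMax = static_cast<Float>(std::numeric_limits<Int>::max());
  // Stand-in for math.roundeven: nearbyint rounds ties to even under the
  // default FE_TONEAREST rounding mode.
  const Float rounded = std::nearbyint(x);
  // Stand-ins for arith.minf then arith.maxf, in the same order as the CHECKs.
  const Float clamped = std::fmax(std::fmin(rounded, intMax), intMin);
  // Stand-in for arith.fptosi; clamped is now in range, so this is defined.
  return static_cast<Int>(clamped);
}

// Usage: a few float -> i8 and double -> i8 conversions.
void checkExamples() {
  assert(castFloatToInt<std::int8_t>(2.5f) == 2);       // tie rounds to even
  assert(castFloatToInt<std::int8_t>(3.5f) == 4);       // tie rounds to even
  assert(castFloatToInt<std::int8_t>(300.0f) == 127);   // clamped to max
  assert(castFloatToInt<std::int8_t>(-1000.0) == -128); // clamped to min
}
```

The clamp runs before the conversion because an out-of-range float-to-integer conversion has no well-defined result, either in C++ or in arith.fptosi. Clamping first makes saturation explicit.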
