[Mlir-commits] [mlir] 47286fc - [mlir][tosa] Add tosa.cast to linalg lowering

Rob Suderman llvmlistbot at llvm.org
Fri Mar 19 11:49:14 PDT 2021


Author: Rob Suderman
Date: 2021-03-19T11:48:37-07:00
New Revision: 47286fc530159dfdbc28f14daaeff4066a1f3b1e

URL: https://github.com/llvm/llvm-project/commit/47286fc530159dfdbc28f14daaeff4066a1f3b1e
DIFF: https://github.com/llvm/llvm-project/commit/47286fc530159dfdbc28f14daaeff4066a1f3b1e.diff

LOG: [mlir][tosa] Add tosa.cast to linalg lowering

Lowers the tosa CastOp to an equivalent elementwise linalg.generic body. It
supports conversions among boolean (i1), integer, and floating-point types.

Reviewed By: antiagainst

Differential Revision: https://reviews.llvm.org/D98828

Added: 
    

Modified: 
    mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
    mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index 903e4cc765aa..72b9aa850213 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -289,6 +289,67 @@ createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args,
                                      rewriter);
   }
 
+  // tosa::CastOp
+  if (isa<tosa::CastOp>(op)) {
+    Type srcTy = elementTy;
+    Type dstTy = resultTypes.front();
+    bool bitExtend =
+        srcTy.getIntOrFloatBitWidth() < dstTy.getIntOrFloatBitWidth();
+
+    if (srcTy == dstTy)
+      return args.front();
+
+    if (srcTy.isa<FloatType>() && dstTy.isa<FloatType>() && bitExtend)
+      return rewriter.create<mlir::FPExtOp>(loc, resultTypes, args, mlir::None);
+
+    if (srcTy.isa<FloatType>() && dstTy.isa<FloatType>() && !bitExtend)
+      return rewriter.create<mlir::FPTruncOp>(loc, resultTypes, args,
+                                              mlir::None);
+
+    // 1-bit integers need to be treated as signless.
+    if (srcTy.isInteger(1) && mlir::UIToFPOp::areCastCompatible(srcTy, dstTy))
+      return rewriter.create<mlir::UIToFPOp>(loc, resultTypes, args,
+                                             mlir::None);
+
+    if (srcTy.isInteger(1) && dstTy.isa<IntegerType>() && bitExtend)
+      return rewriter.create<mlir::ZeroExtendIOp>(loc, resultTypes, args,
+                                                  mlir::None);
+
+    // All other si-to-fp conversions should be handled by SIToFP.
+    if (mlir::SIToFPOp::areCastCompatible(srcTy, dstTy))
+      return rewriter.create<mlir::SIToFPOp>(loc, resultTypes, args,
+                                             mlir::None);
+
+    // Casting to boolean, floats need to only be checked as not-equal to zero.
+    if (srcTy.isa<FloatType>() && dstTy.isInteger(1)) {
+      Value zero =
+          rewriter.create<ConstantOp>(loc, rewriter.getFloatAttr(srcTy, 0.0));
+      return rewriter.create<mlir::CmpFOp>(loc, CmpFPredicate::UNE,
+                                           args.front(), zero);
+    }
+
+    if (mlir::FPToSIOp::areCastCompatible(srcTy, dstTy))
+      return rewriter.create<mlir::FPToSIOp>(loc, resultTypes, args,
+                                             mlir::None);
+
+    // Casting to boolean, integers need to only be checked as not-equal to
+    // zero.
+    if (srcTy.isa<IntegerType>() && dstTy.isInteger(1)) {
+      Value zero =
+          rewriter.create<ConstantIntOp>(loc, 0, srcTy.getIntOrFloatBitWidth());
+      return rewriter.create<mlir::CmpIOp>(loc, CmpIPredicate::ne, args.front(),
+                                           zero);
+    }
+
+    if (srcTy.isa<IntegerType>() && dstTy.isa<IntegerType>() && bitExtend)
+      return rewriter.create<mlir::SignExtendIOp>(loc, resultTypes, args,
+                                                  mlir::None);
+
+    if (srcTy.isa<IntegerType>() && dstTy.isa<IntegerType>() && !bitExtend)
+      return rewriter.create<mlir::TruncateIOp>(loc, resultTypes, args,
+                                                mlir::None);
+  }
+
   (void)rewriter.notifyMatchFailure(
       op, "unhandled op for linalg body calculation for elementwise op");
   return nullptr;
@@ -891,7 +952,7 @@ void mlir::tosa::populateTosaToLinalgOnTensorsConversionPatterns(
       PointwiseConverter<tosa::LogicalAndOp>,
       PointwiseConverter<tosa::LogicalNotOp>,
       PointwiseConverter<tosa::LogicalOrOp>,
-      PointwiseConverter<tosa::LogicalXorOp>,
+      PointwiseConverter<tosa::LogicalXorOp>, PointwiseConverter<tosa::CastOp>,
       PointwiseConverter<tosa::LogicalLeftShiftOp>,
       PointwiseConverter<tosa::LogicalRightShiftOp>,
       PointwiseConverter<tosa::SelectOp>, PointwiseConverter<tosa::GreaterOp>,

diff  --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index 6f99d782d3af..f25eb3f346ba 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -180,6 +180,35 @@ func @test_simple_f32(%arg0: tensor<1xf32>) -> () {
   // CHECK: select
   %18 = "tosa.reluN"(%0) {max_int = 5 : i64, max_fp = 5.0 : f32} : (tensor<1xf32>) -> tensor<1xf32>
 
+  // CHECK: linalg.generic
+  // CHECK: fptosi
+  %19 = "tosa.cast"(%0) : (tensor<1xf32>) -> tensor<1xi32>
+
+  // CHECK: linalg.generic
+  // CHECK: constant 0
+  // CHECK: cmpf
+  %20 = "tosa.cast"(%0) : (tensor<1xf32>) -> tensor<1xi1>
+
+  // CHECK: linalg.generic
+  // CHECK: fptrunc
+  %21 = "tosa.cast"(%0) : (tensor<1xf32>) -> tensor<1xf16>
+
+  // CHECK: linalg.generic
+  // CHECK: yield
+  %22 = "tosa.cast"(%0) : (tensor<1xf32>) -> tensor<1xf32>
+
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @test_simple_f16
+func @test_simple_f16(%arg0: tensor<1xf16>) -> () {
+
+  // CHECK: linalg.generic
+  // CHECK: fpext
+  %0 = "tosa.cast"(%arg0) : (tensor<1xf16>) -> tensor<1xf32>
+
   return
 }
 
@@ -255,6 +284,27 @@ func @test_simple_i32(%arg0: tensor<1xi32>) -> () {
   // CHECK: select
   %15 = "tosa.reluN"(%0) {max_int = 5 : i64, max_fp = 5.0 : f32} : (tensor<1xi32>) -> tensor<1xi32>
 
+  // CHECK: linalg.generic
+  // CHECK: trunci
+  %16 = "tosa.cast"(%0) : (tensor<1xi32>) -> tensor<1xi16>
+
+  // CHECK: linalg.generic
+  // CHECK: yield
+  %17 = "tosa.cast"(%0) : (tensor<1xi32>) -> tensor<1xi32>
+
+  // CHECK: linalg.generic
+  // CHECK: sexti
+  %18 = "tosa.cast"(%0) : (tensor<1xi32>) -> tensor<1xi64>
+
+  // CHECK: linalg.generic
+  // CHECK: constant 0
+  // CHECK: cmpi
+  %19 = "tosa.cast"(%0) : (tensor<1xi32>) -> tensor<1xi1>
+
+  // CHECK: linalg.generic
+  // CHECK: sitofp
+  %20 = "tosa.cast"(%0) : (tensor<1xi32>) -> tensor<1xf32>
+
   return
 }
 


        


More information about the Mlir-commits mailing list