[Mlir-commits] [mlir] 47286fc - [mlir][tosa] Add tosa.cast to linalg lowering
Rob Suderman
llvmlistbot at llvm.org
Fri Mar 19 11:49:14 PDT 2021
Author: Rob Suderman
Date: 2021-03-19T11:48:37-07:00
New Revision: 47286fc530159dfdbc28f14daaeff4066a1f3b1e
URL: https://github.com/llvm/llvm-project/commit/47286fc530159dfdbc28f14daaeff4066a1f3b1e
DIFF: https://github.com/llvm/llvm-project/commit/47286fc530159dfdbc28f14daaeff4066a1f3b1e.diff
LOG: [mlir][tosa] Add tosa.cast to linalg lowering
Handles lowering from the tosa CastOp to the equivalent linalg operations. It
includes support for conversions between bool, integer, and floating-point types.
Reviewed By: antiagainst
Differential Revision: https://reviews.llvm.org/D98828
Added:
Modified:
mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index 903e4cc765aa..72b9aa850213 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -289,6 +289,67 @@ createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args,
rewriter);
}
+ // tosa::CastOp
+ if (isa<tosa::CastOp>(op)) {
+ Type srcTy = elementTy;
+ Type dstTy = resultTypes.front();
+ bool bitExtend =
+ srcTy.getIntOrFloatBitWidth() < dstTy.getIntOrFloatBitWidth();
+
+ if (srcTy == dstTy)
+ return args.front();
+
+ if (srcTy.isa<FloatType>() && dstTy.isa<FloatType>() && bitExtend)
+ return rewriter.create<mlir::FPExtOp>(loc, resultTypes, args, mlir::None);
+
+ if (srcTy.isa<FloatType>() && dstTy.isa<FloatType>() && !bitExtend)
+ return rewriter.create<mlir::FPTruncOp>(loc, resultTypes, args,
+ mlir::None);
+
+ // 1-bit integers need to be treated as signless.
+ if (srcTy.isInteger(1) && mlir::UIToFPOp::areCastCompatible(srcTy, dstTy))
+ return rewriter.create<mlir::UIToFPOp>(loc, resultTypes, args,
+ mlir::None);
+
+ if (srcTy.isInteger(1) && dstTy.isa<IntegerType>() && bitExtend)
+ return rewriter.create<mlir::ZeroExtendIOp>(loc, resultTypes, args,
+ mlir::None);
+
+ // All other si-to-fp conversions should be handled by SIToFP.
+ if (mlir::SIToFPOp::areCastCompatible(srcTy, dstTy))
+ return rewriter.create<mlir::SIToFPOp>(loc, resultTypes, args,
+ mlir::None);
+
+ // Casting to boolean, floats need to only be checked as not-equal to zero.
+ if (srcTy.isa<FloatType>() && dstTy.isInteger(1)) {
+ Value zero =
+ rewriter.create<ConstantOp>(loc, rewriter.getFloatAttr(srcTy, 0.0));
+ return rewriter.create<mlir::CmpFOp>(loc, CmpFPredicate::UNE,
+ args.front(), zero);
+ }
+
+ if (mlir::FPToSIOp::areCastCompatible(srcTy, dstTy))
+ return rewriter.create<mlir::FPToSIOp>(loc, resultTypes, args,
+ mlir::None);
+
+ // Casting to boolean, integers need to only be checked as not-equal to
+ // zero.
+ if (srcTy.isa<IntegerType>() && dstTy.isInteger(1)) {
+ Value zero =
+ rewriter.create<ConstantIntOp>(loc, 0, srcTy.getIntOrFloatBitWidth());
+ return rewriter.create<mlir::CmpIOp>(loc, CmpIPredicate::ne, args.front(),
+ zero);
+ }
+
+ if (srcTy.isa<IntegerType>() && dstTy.isa<IntegerType>() && bitExtend)
+ return rewriter.create<mlir::SignExtendIOp>(loc, resultTypes, args,
+ mlir::None);
+
+ if (srcTy.isa<IntegerType>() && dstTy.isa<IntegerType>() && !bitExtend)
+ return rewriter.create<mlir::TruncateIOp>(loc, resultTypes, args,
+ mlir::None);
+ }
+
(void)rewriter.notifyMatchFailure(
op, "unhandled op for linalg body calculation for elementwise op");
return nullptr;
@@ -891,7 +952,7 @@ void mlir::tosa::populateTosaToLinalgOnTensorsConversionPatterns(
PointwiseConverter<tosa::LogicalAndOp>,
PointwiseConverter<tosa::LogicalNotOp>,
PointwiseConverter<tosa::LogicalOrOp>,
- PointwiseConverter<tosa::LogicalXorOp>,
+ PointwiseConverter<tosa::LogicalXorOp>, PointwiseConverter<tosa::CastOp>,
PointwiseConverter<tosa::LogicalLeftShiftOp>,
PointwiseConverter<tosa::LogicalRightShiftOp>,
PointwiseConverter<tosa::SelectOp>, PointwiseConverter<tosa::GreaterOp>,
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index 6f99d782d3af..f25eb3f346ba 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -180,6 +180,35 @@ func @test_simple_f32(%arg0: tensor<1xf32>) -> () {
// CHECK: select
%18 = "tosa.reluN"(%0) {max_int = 5 : i64, max_fp = 5.0 : f32} : (tensor<1xf32>) -> tensor<1xf32>
+ // CHECK: linalg.generic
+ // CHECK: fptosi
+ %19 = "tosa.cast"(%0) : (tensor<1xf32>) -> tensor<1xi32>
+
+ // CHECK: linalg.generic
+ // CHECK: constant 0
+ // CHECK: cmpf
+ %20 = "tosa.cast"(%0) : (tensor<1xf32>) -> tensor<1xi1>
+
+ // CHECK: linalg.generic
+ // CHECK: fptrunc
+ %21 = "tosa.cast"(%0) : (tensor<1xf32>) -> tensor<1xf16>
+
+ // CHECK: linalg.generic
+ // CHECK: yield
+ %22 = "tosa.cast"(%0) : (tensor<1xf32>) -> tensor<1xf32>
+
+ return
+}
+
+// -----
+
+// CHECK-LABEL: @test_simple_f16
+func @test_simple_f16(%arg0: tensor<1xf16>) -> () {
+
+ // CHECK: linalg.generic
+ // CHECK: fpext
+ %0 = "tosa.cast"(%arg0) : (tensor<1xf16>) -> tensor<1xf32>
+
return
}
@@ -255,6 +284,27 @@ func @test_simple_i32(%arg0: tensor<1xi32>) -> () {
// CHECK: select
%15 = "tosa.reluN"(%0) {max_int = 5 : i64, max_fp = 5.0 : f32} : (tensor<1xi32>) -> tensor<1xi32>
+ // CHECK: linalg.generic
+ // CHECK: trunci
+ %16 = "tosa.cast"(%0) : (tensor<1xi32>) -> tensor<1xi16>
+
+ // CHECK: linalg.generic
+ // CHECK: yield
+ %17 = "tosa.cast"(%0) : (tensor<1xi32>) -> tensor<1xi32>
+
+ // CHECK: linalg.generic
+ // CHECK: sexti
+ %18 = "tosa.cast"(%0) : (tensor<1xi32>) -> tensor<1xi64>
+
+ // CHECK: linalg.generic
+ // CHECK: constant 0
+ // CHECK: cmpi
+ %19 = "tosa.cast"(%0) : (tensor<1xi32>) -> tensor<1xi1>
+
+ // CHECK: linalg.generic
+ // CHECK: sitofp
+ %20 = "tosa.cast"(%0) : (tensor<1xi32>) -> tensor<1xf32>
+
return
}
More information about the Mlir-commits
mailing list