[Mlir-commits] [mlir] 07ce5c9 - [mlir][tosa] Add lowerings for tosa.equal and tosa.arithmetic_right_shift
Rob Suderman
llvmlistbot at llvm.org
Mon May 3 18:31:33 PDT 2021
Author: natashaknk
Date: 2021-05-03T18:26:49-07:00
New Revision: 07ce5c99d791a43efeefbbae30f951703e84bc46
URL: https://github.com/llvm/llvm-project/commit/07ce5c99d791a43efeefbbae30f951703e84bc46
DIFF: https://github.com/llvm/llvm-project/commit/07ce5c99d791a43efeefbbae30f951703e84bc46.diff
LOG: [mlir][tosa] Add lowerings for tosa.equal and tosa.arithmetic_right_shift
Lowers the elementwise ops tosa.equal and tosa.arithmetic_right_shift to the linalg dialect using linalg.generic.
Reviewed By: rsuderman
Differential Revision: https://reviews.llvm.org/D101804
Added:
Modified:
mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index 8107d97cb61f2..ee4f29c0e1ca1 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -227,6 +227,45 @@ createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args,
if (isa<tosa::LogicalRightShiftOp>(op) && elementTy.isa<IntegerType>())
return rewriter.create<mlir::UnsignedShiftRightOp>(loc, resultTypes, args);
+ // tosa::ArithmeticRightShiftOp
+ if (isa<tosa::ArithmeticRightShiftOp>(op) && elementTy.isa<IntegerType>()) {
+ auto result =
+ rewriter.create<mlir::SignedShiftRightOp>(loc, resultTypes, args);
+ auto round = op->getAttr("round").cast<BoolAttr>().getValue();
+ if (!round) {
+ return result;
+ }
+
+ Type i1Ty = IntegerType::get(rewriter.getContext(), /*width=*/1);
+ auto one =
+ rewriter.create<mlir::ConstantOp>(loc, IntegerAttr::get(elementTy, 1));
+ auto zero =
+ rewriter.create<mlir::ConstantOp>(loc, IntegerAttr::get(elementTy, 0));
+ auto i1one =
+ rewriter.create<mlir::ConstantOp>(loc, IntegerAttr::get(i1Ty, 1));
+
+ // Checking that input2 != 0
+ auto shiftValueGreaterThanZero =
+ rewriter.create<mlir::CmpIOp>(loc, CmpIPredicate::sgt, args[1], zero);
+
+ // Checking for the last bit of input1 to be 1
+ auto subtract =
+ rewriter.create<mlir::SubIOp>(loc, resultTypes, args[1], one);
+ auto shifted = rewriter
+ .create<mlir::SignedShiftRightOp>(loc, resultTypes,
+ args[0], subtract)
+ ->getResults();
+ auto truncated =
+ rewriter.create<mlir::TruncateIOp>(loc, i1Ty, shifted, mlir::None);
+ auto isInputOdd = rewriter.create<mlir::AndOp>(loc, i1Ty, truncated, i1one);
+
+ auto shouldRound = rewriter.create<mlir::AndOp>(
+ loc, i1Ty, shiftValueGreaterThanZero, isInputOdd);
+ auto extended =
+ rewriter.create<ZeroExtendIOp>(loc, resultTypes, shouldRound);
+ return rewriter.create<mlir::AddIOp>(loc, resultTypes, result, extended);
+ }
+
// tosa::LogicalAnd
if (isa<tosa::LogicalAndOp>(op) && elementTy.isInteger(1))
return rewriter.create<mlir::AndOp>(loc, resultTypes, args);
@@ -284,6 +323,15 @@ createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args,
return rewriter.create<mlir::CmpIOp>(loc, CmpIPredicate::sge, args[0],
args[1]);
+ // tosa::EqualOp
+ if (isa<tosa::EqualOp>(op) && elementTy.isa<FloatType>())
+ return rewriter.create<mlir::CmpFOp>(loc, CmpFPredicate::OEQ, args[0],
+ args[1]);
+
+ if (isa<tosa::EqualOp>(op) && elementTy.isSignlessInteger())
+ return rewriter.create<mlir::CmpIOp>(loc, CmpIPredicate::eq, args[0],
+ args[1]);
+
// tosa::SelectOp
if (isa<tosa::SelectOp>(op)) {
elementTy = op->getOperand(1).getType().cast<ShapedType>().getElementType();
@@ -2202,9 +2250,11 @@ void mlir::tosa::populateTosaToLinalgOnTensorsConversionPatterns(
PointwiseConverter<tosa::CastOp>,
PointwiseConverter<tosa::LogicalLeftShiftOp>,
PointwiseConverter<tosa::LogicalRightShiftOp>,
+ PointwiseConverter<tosa::ArithmeticRightShiftOp>,
PointwiseConverter<tosa::SelectOp>,
PointwiseConverter<tosa::GreaterOp>,
PointwiseConverter<tosa::GreaterEqualOp>,
+ PointwiseConverter<tosa::EqualOp>,
PointwiseConverter<tosa::MaximumOp>,
PointwiseConverter<tosa::MinimumOp>,
PointwiseConverter<tosa::CeilOp>,
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index 775ac6d06ed4e..9bd03f10e68e3 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -151,65 +151,69 @@ func @test_simple_f32(%arg0: tensor<1xf32>) -> () {
// CHECK: cmpf
%11 = "tosa.greater_equal"(%0, %1) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xi1>
+ // CHECK: linalg.generic
+ // CHECK: cmpf
+ %12 = "tosa.equal"(%0, %1) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xi1>
+
// CHECK: linalg.generic
// CHECK: select
- %12 = "tosa.select"(%10, %0, %1) : (tensor<1xi1>, tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32>
+ %13 = "tosa.select"(%10, %0, %1) : (tensor<1xi1>, tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32>
// CHECK: linalg.generic
// CHECK: cmpf
// CHECK: select
- %13 = "tosa.maximum"(%0, %1) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32>
+ %14 = "tosa.maximum"(%0, %1) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32>
// CHECK: linalg.generic
// CHECK: cmpf
// CHECK: select
- %14 = "tosa.minimum"(%0, %1) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32>
+ %15 = "tosa.minimum"(%0, %1) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32>
// CHECK: linalg.generic
// CHECK: ceil
- %15 = "tosa.ceil"(%0) : (tensor<1xf32>) -> tensor<1xf32>
+ %16 = "tosa.ceil"(%0) : (tensor<1xf32>) -> tensor<1xf32>
// CHECK: linalg.generic
// CHECK: floor
- %16 = "tosa.floor"(%0) : (tensor<1xf32>) -> tensor<1xf32>
+ %17 = "tosa.floor"(%0) : (tensor<1xf32>) -> tensor<1xf32>
// CHECK: linalg.generic
// CHECK: cmpf
// CHECK: select
- %17 = "tosa.clamp"(%0) {min_int = 1 : i64, max_int = 5 : i64, min_fp = 1.0 : f32, max_fp = 5.0 : f32} : (tensor<1xf32>) -> tensor<1xf32>
+ %18 = "tosa.clamp"(%0) {min_int = 1 : i64, max_int = 5 : i64, min_fp = 1.0 : f32, max_fp = 5.0 : f32} : (tensor<1xf32>) -> tensor<1xf32>
// CHECK: linalg.generic
// CHECK: cmpf
// CHECK: select
- %18 = "tosa.reluN"(%0) {max_int = 5 : i64, max_fp = 5.0 : f32} : (tensor<1xf32>) -> tensor<1xf32>
+ %19 = "tosa.reluN"(%0) {max_int = 5 : i64, max_fp = 5.0 : f32} : (tensor<1xf32>) -> tensor<1xf32>
// CHECK: linalg.generic
// CHECK: negf
// CHECK: exp
// CHECK: addf
// CHECK: divf
- %19 = "tosa.sigmoid"(%0) : (tensor<1xf32>) -> tensor<1xf32>
+ %20 = "tosa.sigmoid"(%0) : (tensor<1xf32>) -> tensor<1xf32>
// CHECK: linalg.generic
// CHECK: fptosi
- %20 = "tosa.cast"(%0) : (tensor<1xf32>) -> tensor<1xi32>
+ %21 = "tosa.cast"(%0) : (tensor<1xf32>) -> tensor<1xi32>
// CHECK: linalg.generic
// CHECK: constant 0
// CHECK: cmpf
- %21 = "tosa.cast"(%0) : (tensor<1xf32>) -> tensor<1xi1>
+ %22 = "tosa.cast"(%0) : (tensor<1xf32>) -> tensor<1xi1>
// CHECK: linalg.generic
// CHECK: fptrunc
- %22 = "tosa.cast"(%0) : (tensor<1xf32>) -> tensor<1xf16>
+ %23 = "tosa.cast"(%0) : (tensor<1xf32>) -> tensor<1xf16>
// CHECK: linalg.generic
// CHECK: yield
- %23 = "tosa.cast"(%0) : (tensor<1xf32>) -> tensor<1xf32>
+ %24 = "tosa.cast"(%0) : (tensor<1xf32>) -> tensor<1xf32>
// CHECK: linalg.generic
// CHECK: divf
- %24 = "tosa.reciprocal"(%0) : (tensor<1xf32>) -> tensor<1xf32>
+ %25 = "tosa.reciprocal"(%0) : (tensor<1xf32>) -> tensor<1xf32>
return
}
@@ -285,58 +289,76 @@ func @test_simple_i32(%arg0: tensor<1xi32>) -> () {
// CHECK: shift_right_unsigned
%9 = "tosa.logical_right_shift"(%arg0, %arg0) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
+ // CHECK: linalg.generic
+ // CHECK: shift_right_signed
+ %10 = "tosa.arithmetic_right_shift"(%arg0, %arg0) {round = 0 : i1} : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
+
+ // CHECK: linalg.generic
+ // CHECK: constant 1
+ // CHECK: constant 0
+ // CHECK: constant true
+ // CHECK: cmpi
+ // CHECK: subi
+ // CHECK: shift_right_signed
+ // CHECK: trunci
+ // CHECK: and
+ // CHECK: and
+ // CHECK: zexti
+ // CHECK: addi
+ %11 = "tosa.arithmetic_right_shift"(%arg0, %arg0) {round = 1 : i1} : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
+
// CHECK: linalg.generic
// CHECK: cmpi
- %10 = "tosa.greater"(%0, %1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1>
+ %12 = "tosa.greater"(%0, %1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1>
// CHECK: linalg.generic
// CHECK: cmpi
- %11 = "tosa.greater_equal"(%0, %1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1>
+ %13 = "tosa.greater_equal"(%0, %1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1>
// CHECK: linalg.generic
// CHECK: select
- %12 = "tosa.select"(%10, %0, %1) : (tensor<1xi1>, tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
+ %14 = "tosa.select"(%12, %0, %1) : (tensor<1xi1>, tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
// CHECK: linalg.generic
// CHECK: cmpi
// CHECK: select
- %13 = "tosa.maximum"(%0, %1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
+ %15 = "tosa.maximum"(%0, %1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
// CHECK: linalg.generic
// CHECK: cmpi
// CHECK: select
- %14 = "tosa.minimum"(%0, %1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
+ %16 = "tosa.minimum"(%0, %1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
// CHECK: linalg.generic
// CHECK: cmpi
// CHECK: select
- %15 = "tosa.clamp"(%0) {min_int = 1 : i64, max_int = 5 : i64, min_fp = 1.0 : f32, max_fp = 5.0 : f32} : (tensor<1xi32>) -> tensor<1xi32>
+ %17 = "tosa.clamp"(%0) {min_int = 1 : i64, max_int = 5 : i64, min_fp = 1.0 : f32, max_fp = 5.0 : f32} : (tensor<1xi32>) -> tensor<1xi32>
// CHECK: linalg.generic
// CHECK: cmpi
// CHECK: select
- %16 = "tosa.reluN"(%0) {max_int = 5 : i64, max_fp = 5.0 : f32} : (tensor<1xi32>) -> tensor<1xi32>
+ %18 = "tosa.reluN"(%0) {max_int = 5 : i64, max_fp = 5.0 : f32} : (tensor<1xi32>) -> tensor<1xi32>
// CHECK: linalg.generic
// CHECK: trunci
- %17 = "tosa.cast"(%0) : (tensor<1xi32>) -> tensor<1xi16>
+ %19 = "tosa.cast"(%0) : (tensor<1xi32>) -> tensor<1xi16>
// CHECK: linalg.generic
// CHECK: yield
- %18 = "tosa.cast"(%0) : (tensor<1xi32>) -> tensor<1xi32>
+ %20 = "tosa.cast"(%0) : (tensor<1xi32>) -> tensor<1xi32>
// CHECK: linalg.generic
// CHECK: sexti
- %19 = "tosa.cast"(%0) : (tensor<1xi32>) -> tensor<1xi64>
+ %21 = "tosa.cast"(%0) : (tensor<1xi32>) -> tensor<1xi64>
// CHECK: linalg.generic
// CHECK: constant 0
// CHECK: cmpi
- %20 = "tosa.cast"(%0) : (tensor<1xi32>) -> tensor<1xi1>
+ %22 = "tosa.cast"(%0) : (tensor<1xi32>) -> tensor<1xi1>
// CHECK: linalg.generic
// CHECK: sitofp
- %21 = "tosa.cast"(%0) : (tensor<1xi32>) -> tensor<1xf32>
+ %23 = "tosa.cast"(%0) : (tensor<1xi32>) -> tensor<1xf32>
return
}
More information about the Mlir-commits
mailing list