[Mlir-commits] [mlir] 21724dd - [MLIR][TOSA] Comparison based elementwise operations for tosa-to-linalg
Rob Suderman
llvmlistbot at llvm.org
Mon Feb 1 21:44:40 PST 2021
Author: natashaknk
Date: 2021-02-01T21:37:52-08:00
New Revision: 21724ddcb7033cb010d57ff1a2d593cd70d462f5
URL: https://github.com/llvm/llvm-project/commit/21724ddcb7033cb010d57ff1a2d593cd70d462f5
DIFF: https://github.com/llvm/llvm-project/commit/21724ddcb7033cb010d57ff1a2d593cd70d462f5.diff
LOG: [MLIR][TOSA] Comparison based elementwise operations for tosa-to-linalg
Committed log, exp, maximum, minimum, comparison (greater, greater_equal), ceil, and floor conversions from TOSA to Linalg. Supports signless integer and floating-point element types.
Reviewed By: rsuderman
Differential Revision: https://reviews.llvm.org/D95839
Added:
Modified:
mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index 6ad5a256f98a..dd4bb2a3d016 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -29,7 +29,7 @@ createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args,
PatternRewriter &rewriter) {
Location loc = op->getLoc();
auto elementTy =
- op->getResult(0).getType().cast<ShapedType>().getElementType();
+ op->getOperand(0).getType().cast<ShapedType>().getElementType();
// tosa::AbsOp
if (isa<tosa::AbsOp>(op) && elementTy.isa<FloatType>())
@@ -66,6 +66,14 @@ createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args,
if (isa<tosa::PowOp>(op) && elementTy.isa<FloatType>())
return rewriter.create<mlir::PowFOp>(loc, resultTypes, args);
+ // tosa::LogOp
+ if (isa<tosa::LogOp>(op) && elementTy.isa<FloatType>())
+ return rewriter.create<mlir::LogOp>(loc, resultTypes, args);
+
+ // tosa::ExpOp
+ if (isa<tosa::ExpOp>(op) && elementTy.isa<FloatType>())
+ return rewriter.create<mlir::ExpOp>(loc, resultTypes, args);
+
// tosa::SubOp
if (isa<tosa::SubOp>(op) && elementTy.isa<FloatType>())
return rewriter.create<mlir::SubFOp>(loc, resultTypes, args);
@@ -77,6 +85,58 @@ createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args,
if (isa<tosa::TanhOp>(op) && elementTy.isa<FloatType>())
return rewriter.create<mlir::TanhOp>(loc, resultTypes, args);
+ // tosa::GreaterOp
+ if (isa<tosa::GreaterOp>(op) && elementTy.isa<FloatType>())
+ return rewriter.create<mlir::CmpFOp>(loc, CmpFPredicate::OGT, args[0],
+ args[1]);
+
+ if (isa<tosa::GreaterOp>(op) && elementTy.isSignlessInteger())
+ return rewriter.create<mlir::CmpIOp>(loc, CmpIPredicate::sgt, args[0],
+ args[1]);
+
+ // tosa::GreaterEqualOp
+ if (isa<tosa::GreaterEqualOp>(op) && elementTy.isa<FloatType>())
+ return rewriter.create<mlir::CmpFOp>(loc, CmpFPredicate::OGE, args[0],
+ args[1]);
+
+ if (isa<tosa::GreaterEqualOp>(op) && elementTy.isSignlessInteger())
+ return rewriter.create<mlir::CmpIOp>(loc, CmpIPredicate::sge, args[0],
+ args[1]);
+
+ // tosa::MaximumOp
+ if (isa<tosa::MaximumOp>(op) && elementTy.isa<FloatType>()) {
+ auto predicate = rewriter.create<mlir::CmpFOp>(loc, CmpFPredicate::OGT,
+ args[0], args[1]);
+ return rewriter.create<mlir::SelectOp>(loc, predicate, args[0], args[1]);
+ }
+
+ if (isa<tosa::MaximumOp>(op) && elementTy.isSignlessInteger()) {
+ auto predicate = rewriter.create<mlir::CmpIOp>(loc, CmpIPredicate::sgt,
+ args[0], args[1]);
+ return rewriter.create<mlir::SelectOp>(loc, predicate, args[0], args[1]);
+ }
+
+ // tosa::MinimumOp
+ if (isa<tosa::MinimumOp>(op) && elementTy.isa<FloatType>()) {
+ auto predicate = rewriter.create<mlir::CmpFOp>(loc, CmpFPredicate::OLT,
+ args[0], args[1]);
+ return rewriter.create<mlir::SelectOp>(loc, predicate, args[0], args[1]);
+ }
+
+ if (isa<tosa::MinimumOp>(op) && elementTy.isSignlessInteger()) {
+ auto predicate = rewriter.create<mlir::CmpIOp>(loc, CmpIPredicate::slt,
+ args[0], args[1]);
+ return rewriter.create<mlir::SelectOp>(loc, predicate, args[0], args[1]);
+ }
+
+ // tosa::CeilOp
+ if (isa<tosa::CeilOp>(op) && elementTy.isa<FloatType>())
+ return rewriter.create<mlir::CeilFOp>(loc, resultTypes, args);
+
+ // tosa::FloorOp
+ if (isa<tosa::FloorOp>(op) && elementTy.isa<FloatType>())
+ return rewriter.create<mlir::FloorFOp>(loc, resultTypes, args);
+
rewriter.notifyMatchFailure(
op, "unhandled op for linalg body calculation for elementwise op");
return nullptr;
@@ -94,19 +154,21 @@ elementwiseMatchAndRewriteHelper(Operation *operation,
// For now require no broadcasting. Consider making it support broadcasting
// operations.
- Type uniqueTy = operation->getOperand(0).getType();
+ Type uniqueInTy = operation->getOperand(0).getType();
bool allInputTypesEqual =
llvm::all_of(operation->getOperandTypes(),
- [&](Type operandTy) { return operandTy == uniqueTy; });
+ [&](Type operandTy) { return operandTy == uniqueInTy; });
if (!allInputTypesEqual)
return rewriter.notifyMatchFailure(operation,
"All operands must have the same type");
- bool allResultTypesEqual =
- llvm::all_of(operation->getResultTypes(),
- [&](Type resultTy) { return resultTy == uniqueTy; });
- if (!allResultTypesEqual)
+ bool resultAndInputShapeEqual =
+ llvm::all_of(operation->getResultTypes(), [&](Type resultTy) {
+ return resultTy.cast<ShapedType>().getShape() == t0.getShape();
+ });
+
+ if (!resultAndInputShapeEqual)
return rewriter.notifyMatchFailure(
- operation, "All results must have the same type as the input");
+ operation, "All results must have the same shape as the input");
// Construct the indexing maps needed for linalg.generic ops.
SmallVector<Type> bodyArgTypes;
@@ -179,10 +241,16 @@ void mlir::tosa::populateTosaToLinalgOnTensorsConversionPatterns(
MLIRContext *context, OwningRewritePatternList *patterns) {
patterns->insert<
PointwiseConverter<tosa::AddOp>, PointwiseConverter<tosa::SubOp>,
- PointwiseConverter<tosa::PowOp>, PointwiseConverter<tosa::AbsOp>,
+ PointwiseConverter<tosa::PowOp>, PointwiseConverter<tosa::LogOp>,
+ PointwiseConverter<tosa::ExpOp>, PointwiseConverter<tosa::AbsOp>,
PointwiseConverter<tosa::TanhOp>, PointwiseConverter<tosa::BitwiseAndOp>,
PointwiseConverter<tosa::BitwiseOrOp>,
PointwiseConverter<tosa::BitwiseXorOp>,
PointwiseConverter<tosa::LogicalLeftShiftOp>,
- PointwiseConverter<tosa::LogicalRightShiftOp>>(context);
+ PointwiseConverter<tosa::LogicalRightShiftOp>,
+ PointwiseConverter<tosa::GreaterOp>,
+ PointwiseConverter<tosa::GreaterEqualOp>,
+ PointwiseConverter<tosa::MaximumOp>, PointwiseConverter<tosa::MinimumOp>,
+ PointwiseConverter<tosa::CeilOp>, PointwiseConverter<tosa::FloorOp>>(
+ context);
}
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index 6f849f7030d3..e416246a19a4 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -100,6 +100,41 @@ func @test_simple_f32(%arg0: tensor<1xf32>) -> () {
// CHECK: linalg.generic
// CHECK: pow
%4 = "tosa.pow"(%1, %2) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32>
+
+ // CHECK: linalg.generic
+ // CHECK: log
+ %5 = "tosa.log"(%arg0) : (tensor<1xf32>) -> tensor<1xf32>
+
+ // CHECK: linalg.generic
+ // CHECK: exp
+ %6 = "tosa.exp"(%arg0) : (tensor<1xf32>) -> tensor<1xf32>
+
+ // CHECK: linalg.generic
+ // CHECK: cmpf
+ %7 = "tosa.greater"(%0, %1) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xi1>
+
+ // CHECK: linalg.generic
+ // CHECK: cmpf
+ %8 = "tosa.greater_equal"(%0, %1) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xi1>
+
+ // CHECK: linalg.generic
+ // CHECK: cmpf
+ // CHECK: select
+ %9 = "tosa.maximum"(%0, %1) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32>
+
+ // CHECK: linalg.generic
+ // CHECK: cmpf
+ // CHECK: select
+ %10 = "tosa.minimum"(%0, %1) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32>
+
+ // CHECK: linalg.generic
+ // CHECK: ceil
+ %11 = "tosa.ceil"(%0) : (tensor<1xf32>) -> tensor<1xf32>
+
+ // CHECK: linalg.generic
+ // CHECK: floor
+ %12 = "tosa.floor"(%0) : (tensor<1xf32>) -> tensor<1xf32>
+
return
}
@@ -135,6 +170,25 @@ func @test_simple_i32(%arg0: tensor<1xi32>) -> () {
// CHECK: shift_right_unsigned
%6 = "tosa.logical_right_shift"(%arg0, %arg0) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
+ // CHECK: linalg.generic
+ // CHECK: cmpi
+ %7 = "tosa.greater"(%0, %1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1>
+
+ // CHECK: linalg.generic
+ // CHECK: cmpi
+ %8 = "tosa.greater_equal"(%0, %1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1>
+
+ // CHECK: linalg.generic
+ // CHECK: cmpi
+ // CHECK: select
+ %9 = "tosa.maximum"(%0, %1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
+
+ // CHECK: linalg.generic
+ // CHECK: cmpi
+ // CHECK: select
+ %10 = "tosa.minimum"(%0, %1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
+
+
return
}
More information about the Mlir-commits mailing list