[Mlir-commits] [mlir] 25b4a6a - [MLIR][TOSA] Add lowering from TOSA to Linalg for math-based and elementwise ops
Rob Suderman
llvmlistbot at llvm.org
Thu Feb 18 12:17:15 PST 2021
Author: natashaknk
Date: 2021-02-18T12:10:10-08:00
New Revision: 25b4a6a7f038184ba77dd3c0d8605da454bb4a06
URL: https://github.com/llvm/llvm-project/commit/25b4a6a7f038184ba77dd3c0d8605da454bb4a06
DIFF: https://github.com/llvm/llvm-project/commit/25b4a6a7f038184ba77dd3c0d8605da454bb4a06.diff
LOG: [MLIR][TOSA] Add lowering from TOSA to Linalg for math-based and elementwise ops
This patch adds lowering to Linalg for the following TOSA ops: negate, rsqrt, mul, select, clamp, and reluN. It includes support for signless integer and floating-point types.
Reviewed By: rsuderman
Differential Revision: https://reviews.llvm.org/D96924
Added:
Modified:
mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index 63c99f648392..8e096e48d2d3 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -24,6 +24,28 @@ static SmallVector<StringRef> getNParallelLoopsAttrs(unsigned nParallelLoops) {
return SmallVector<StringRef>(nParallelLoops, getParallelIteratorTypeName());
}
+/// Materializes the integer attribute `attrName` of `op` as an
+/// `mlir::ConstantOp` of type `requiredAttrType`.
+///
+/// The attribute is read sign-extended and first narrowed through `T`
+/// (callers use int32_t). When the target is an integer type narrower than
+/// 64 bits, the value is then saturated to that type's signed range.
+/// IntegerAttr::get does not range-check its int64_t argument, so without
+/// this an out-of-range bound may be truncated: e.g. a max_int of INT32_MAX
+/// on an i8 tensor would keep only its low byte (0xFF == -1).
+///
+/// Assumes `op` carries `attrName` as an IntegerAttr; `cast` asserts otherwise.
+template <typename T>
+static mlir::ConstantOp
+createConstFromIntAttribute(Operation *op, StringRef attrName,
+                            Type requiredAttrType, PatternRewriter &rewriter) {
+  int64_t value = static_cast<T>(
+      op->getAttr(attrName).cast<IntegerAttr>().getValue().getSExtValue());
+
+  // Saturate rather than truncate. For the signed (slt) clamp built by the
+  // callers, a bound beyond the element type's range behaves exactly like
+  // the type's own extreme value.
+  if (auto intTy = requiredAttrType.dyn_cast<IntegerType>()) {
+    unsigned width = intTy.getWidth();
+    if (width > 0 && width < 64) {
+      int64_t maxVal = (int64_t(1) << (width - 1)) - 1;
+      int64_t minVal = -maxVal - 1;
+      if (value < minVal)
+        value = minVal;
+      else if (value > maxVal)
+        value = maxVal;
+    }
+  }
+
+  return rewriter.create<mlir::ConstantOp>(
+      op->getLoc(), IntegerAttr::get(requiredAttrType, value));
+}
+
+/// Emits IR that clamps `args[0]` into [`min`, `max`] using the comparison
+/// op `T` with predicate `pred` (a "less than" predicate at the call sites).
+///
+/// Two compare/select pairs are created, in order:
+///   lowered = (args[0] < min) ? min : args[0]
+///   result  = (max < args[0]) ? max : lowered
+/// The returned SelectOp yields the clamped value. Only `op`'s location is
+/// used from `op`.
+template <typename T, typename P>
+static mlir::SelectOp clampHelper(Operation *op, ValueRange args,
+                                  mlir::ConstantOp min, mlir::ConstantOp max,
+                                  P pred, PatternRewriter &rewriter) {
+  Location opLoc = op->getLoc();
+  Value input = args[0];
+
+  // Lower bound: replace values below `min` with `min`.
+  Value belowMin = rewriter.create<T>(opLoc, pred, input, min);
+  Value lowered =
+      rewriter.create<mlir::SelectOp>(opLoc, belowMin, min, input);
+
+  // Upper bound: replace values above `max` with `max`.
+  Value aboveMax = rewriter.create<T>(opLoc, pred, max, input);
+  return rewriter.create<mlir::SelectOp>(opLoc, aboveMax, max, lowered);
+}
+
static Value
createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args,
ArrayRef<Type> resultTypes,
@@ -43,6 +65,42 @@ createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args,
if (isa<tosa::AddOp>(op) && elementTy.isa<IntegerType>())
return rewriter.create<mlir::AddIOp>(loc, resultTypes, args);
+ // tosa::SubOp
+ if (isa<tosa::SubOp>(op) && elementTy.isa<FloatType>())
+ return rewriter.create<mlir::SubFOp>(loc, resultTypes, args);
+
+ if (isa<tosa::SubOp>(op) && elementTy.isa<IntegerType>())
+ return rewriter.create<mlir::SubIOp>(loc, resultTypes, args);
+
+ // tosa::MulOp
+ if (isa<tosa::MulOp>(op) && elementTy.isa<FloatType>()) {
+ if (dyn_cast<tosa::MulOp>(op).shift() != 0) {
+ (void)rewriter.notifyMatchFailure(op,
+ "Cannot have shift value for float");
+ return nullptr;
+ }
+ return rewriter.create<mlir::MulFOp>(loc, resultTypes, args);
+ }
+
+ if (isa<tosa::MulOp>(op) && elementTy.isa<IntegerType>()) {
+ auto mul =
+ rewriter.create<mlir::MulIOp>(loc, resultTypes, args[0], args[1]);
+ auto constant =
+ rewriter.create<mlir::ConstantOp>(loc, elementTy, op->getAttr("shift"));
+ return rewriter.create<mlir::SignedShiftRightOp>(loc, resultTypes, mul,
+ constant);
+ }
+
+ // tosa::NegateOp
+ if (isa<tosa::NegateOp>(op) && elementTy.isa<IntegerType>()) {
+ auto constant =
+ rewriter.create<mlir::ConstantOp>(loc, IntegerAttr::get(elementTy, -1));
+ return rewriter.create<mlir::MulIOp>(loc, resultTypes, args[0], constant);
+ }
+
+ if (isa<tosa::NegateOp>(op) && elementTy.isa<FloatType>())
+ return rewriter.create<mlir::NegFOp>(loc, resultTypes, args);
+
// tosa::BitwiseAndOp
if (isa<tosa::BitwiseAndOp>(op) && elementTy.isa<IntegerType>())
return rewriter.create<mlir::AndOp>(loc, resultTypes, args);
@@ -67,6 +125,10 @@ createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args,
if (isa<tosa::PowOp>(op) && elementTy.isa<FloatType>())
return rewriter.create<mlir::math::PowFOp>(loc, resultTypes, args);
+ // tosa::RsqrtOp
+ if (isa<tosa::RsqrtOp>(op) && elementTy.isa<FloatType>())
+ return rewriter.create<mlir::math::RsqrtOp>(loc, resultTypes, args);
+
// tosa::LogOp
if (isa<tosa::LogOp>(op) && elementTy.isa<FloatType>())
return rewriter.create<mlir::math::LogOp>(loc, resultTypes, args);
@@ -75,13 +137,6 @@ createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args,
if (isa<tosa::ExpOp>(op) && elementTy.isa<FloatType>())
return rewriter.create<mlir::math::ExpOp>(loc, resultTypes, args);
- // tosa::SubOp
- if (isa<tosa::SubOp>(op) && elementTy.isa<FloatType>())
- return rewriter.create<mlir::SubFOp>(loc, resultTypes, args);
-
- if (isa<tosa::SubOp>(op) && elementTy.isa<IntegerType>())
- return rewriter.create<mlir::SubIOp>(loc, resultTypes, args);
-
// tosa::TanhOp
if (isa<tosa::TanhOp>(op) && elementTy.isa<FloatType>())
return rewriter.create<mlir::math::TanhOp>(loc, resultTypes, args);
@@ -104,6 +159,13 @@ createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args,
return rewriter.create<mlir::CmpIOp>(loc, CmpIPredicate::sge, args[0],
args[1]);
+ // tosa::SelectOp
+ if (isa<tosa::SelectOp>(op)) {
+ elementTy = op->getOperand(1).getType().cast<ShapedType>().getElementType();
+ if (elementTy.isa<FloatType>() || elementTy.isa<IntegerType>())
+ return rewriter.create<mlir::SelectOp>(loc, args[0], args[1], args[2]);
+ }
+
// tosa::MaximumOp
if (isa<tosa::MaximumOp>(op) && elementTy.isa<FloatType>()) {
auto predicate = rewriter.create<mlir::CmpFOp>(loc, CmpFPredicate::OGT,
@@ -138,6 +200,44 @@ createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args,
if (isa<tosa::FloorOp>(op) && elementTy.isa<FloatType>())
return rewriter.create<mlir::FloorFOp>(loc, resultTypes, args);
+ // tosa::ClampOp
+ if (isa<tosa::ClampOp>(op) && elementTy.isa<FloatType>()) {
+ auto min = rewriter.create<mlir::ConstantOp>(loc, elementTy,
+ op->getAttr("min_fp"));
+ auto max = rewriter.create<mlir::ConstantOp>(loc, elementTy,
+ op->getAttr("max_fp"));
+ return clampHelper<mlir::CmpFOp>(op, args, min, max, CmpFPredicate::OLT,
+ rewriter);
+ }
+
+ if (isa<tosa::ClampOp>(op) && elementTy.isa<IntegerType>()) {
+ auto min = createConstFromIntAttribute<int32_t>(op, "min_int", elementTy,
+ rewriter);
+ auto max = createConstFromIntAttribute<int32_t>(op, "max_int", elementTy,
+ rewriter);
+ return clampHelper<mlir::CmpIOp>(op, args, min, max, CmpIPredicate::slt,
+ rewriter);
+ }
+
+ // tosa::ReluNOp
+ if (isa<tosa::ReluNOp>(op) && elementTy.isa<FloatType>()) {
+ auto zero =
+ rewriter.create<mlir::ConstantOp>(loc, FloatAttr::get(elementTy, 0));
+ auto n = rewriter.create<mlir::ConstantOp>(loc, elementTy,
+ op->getAttr("max_fp"));
+ return clampHelper<mlir::CmpFOp>(op, args, zero, n, CmpFPredicate::OLT,
+ rewriter);
+ }
+
+ if (isa<tosa::ReluNOp>(op) && elementTy.isa<IntegerType>()) {
+ auto zero =
+ rewriter.create<mlir::ConstantOp>(loc, IntegerAttr::get(elementTy, 0));
+ auto n = createConstFromIntAttribute<int32_t>(op, "max_int", elementTy,
+ rewriter);
+ return clampHelper<mlir::CmpIOp>(op, args, zero, n, CmpIPredicate::slt,
+ rewriter);
+ }
+
(void)rewriter.notifyMatchFailure(
op, "unhandled op for linalg body calculation for elementwise op");
return nullptr;
@@ -245,16 +345,19 @@ void mlir::tosa::populateTosaToLinalgOnTensorsConversionPatterns(
MLIRContext *context, OwningRewritePatternList *patterns) {
patterns->insert<
PointwiseConverter<tosa::AddOp>, PointwiseConverter<tosa::SubOp>,
- PointwiseConverter<tosa::PowOp>, PointwiseConverter<tosa::LogOp>,
- PointwiseConverter<tosa::ExpOp>, PointwiseConverter<tosa::AbsOp>,
- PointwiseConverter<tosa::TanhOp>, PointwiseConverter<tosa::BitwiseAndOp>,
+ PointwiseConverter<tosa::MulOp>, PointwiseConverter<tosa::NegateOp>,
+ PointwiseConverter<tosa::PowOp>, PointwiseConverter<tosa::RsqrtOp>,
+ PointwiseConverter<tosa::LogOp>, PointwiseConverter<tosa::ExpOp>,
+ PointwiseConverter<tosa::AbsOp>, PointwiseConverter<tosa::TanhOp>,
+ PointwiseConverter<tosa::BitwiseAndOp>,
PointwiseConverter<tosa::BitwiseOrOp>,
PointwiseConverter<tosa::BitwiseXorOp>,
PointwiseConverter<tosa::LogicalLeftShiftOp>,
PointwiseConverter<tosa::LogicalRightShiftOp>,
- PointwiseConverter<tosa::GreaterOp>,
+ PointwiseConverter<tosa::SelectOp>, PointwiseConverter<tosa::GreaterOp>,
PointwiseConverter<tosa::GreaterEqualOp>,
PointwiseConverter<tosa::MaximumOp>, PointwiseConverter<tosa::MinimumOp>,
- PointwiseConverter<tosa::CeilOp>, PointwiseConverter<tosa::FloorOp>>(
+ PointwiseConverter<tosa::CeilOp>, PointwiseConverter<tosa::FloorOp>,
+ PointwiseConverter<tosa::ClampOp>, PointwiseConverter<tosa::ReluNOp>>(
context);
}
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index 8963544838e1..022421459d16 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -116,43 +116,69 @@ func @test_simple_f32(%arg0: tensor<1xf32>) -> () {
// CHECK: subf
%3 = "tosa.sub"(%0, %1) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32>
+ // CHECK: linalg.generic
+ // CHECK: mulf
+ %4 = "tosa.mul"(%0, %1) {shift = 0 : i32} : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32>
+
+ // CHECK: linalg.generic
+ // CHECK: negf
+ %5 = "tosa.negate"(%0) : (tensor<1xf32>) -> tensor<1xf32>
+
// CHECK: linalg.generic
// CHECK: pow
- %4 = "tosa.pow"(%1, %2) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32>
+ %6 = "tosa.pow"(%1, %2) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32>
+
+ // CHECK: linalg.generic
+ // CHECK: rsqrt
+ %7 = "tosa.rsqrt"(%1) : (tensor<1xf32>) -> tensor<1xf32>
// CHECK: linalg.generic
// CHECK: log
- %5 = "tosa.log"(%arg0) : (tensor<1xf32>) -> tensor<1xf32>
+ %8 = "tosa.log"(%arg0) : (tensor<1xf32>) -> tensor<1xf32>
// CHECK: linalg.generic
// CHECK: exp
- %6 = "tosa.exp"(%arg0) : (tensor<1xf32>) -> tensor<1xf32>
+ %9 = "tosa.exp"(%arg0) : (tensor<1xf32>) -> tensor<1xf32>
// CHECK: linalg.generic
// CHECK: cmpf
- %7 = "tosa.greater"(%0, %1) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xi1>
+ %10 = "tosa.greater"(%0, %1) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xi1>
// CHECK: linalg.generic
// CHECK: cmpf
- %8 = "tosa.greater_equal"(%0, %1) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xi1>
+ %11 = "tosa.greater_equal"(%0, %1) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xi1>
+
+ // CHECK: linalg.generic
+ // CHECK: select
+ %12 = "tosa.select"(%10, %0, %1) : (tensor<1xi1>, tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32>
// CHECK: linalg.generic
// CHECK: cmpf
// CHECK: select
- %9 = "tosa.maximum"(%0, %1) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32>
+ %13 = "tosa.maximum"(%0, %1) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32>
// CHECK: linalg.generic
// CHECK: cmpf
// CHECK: select
- %10 = "tosa.minimum"(%0, %1) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32>
+ %14 = "tosa.minimum"(%0, %1) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32>
// CHECK: linalg.generic
// CHECK: ceil
- %11 = "tosa.ceil"(%0) : (tensor<1xf32>) -> tensor<1xf32>
+ %15 = "tosa.ceil"(%0) : (tensor<1xf32>) -> tensor<1xf32>
// CHECK: linalg.generic
// CHECK: floor
- %12 = "tosa.floor"(%0) : (tensor<1xf32>) -> tensor<1xf32>
+ %16 = "tosa.floor"(%0) : (tensor<1xf32>) -> tensor<1xf32>
+
+ // CHECK: linalg.generic
+ // CHECK: cmpf
+ // CHECK: select
+ %17 = "tosa.clamp"(%0) {min_int = 1 : i64, max_int = 5 : i64, min_fp = 1.0 : f32, max_fp = 5.0 : f32} : (tensor<1xf32>) -> tensor<1xf32>
+
+ // CHECK: linalg.generic
+ // CHECK: cmpf
+ // CHECK: select
+ %18 = "tosa.reluN"(%0) {max_int = 5 : i64, max_fp = 5.0 : f32} : (tensor<1xf32>) -> tensor<1xf32>
return
}
@@ -169,44 +195,65 @@ func @test_simple_i32(%arg0: tensor<1xi32>) -> () {
// CHECK: subi
%1 = "tosa.sub"(%arg0, %arg0) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
+ // CHECK: linalg.generic
+ // CHECK: muli
+ %2 = "tosa.mul"(%arg0, %arg0) {shift = 0 : i32} : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
+
+ // CHECK: linalg.generic
+ // CHECK: muli
+ %3 = "tosa.negate"(%arg0) : (tensor<1xi32>) -> tensor<1xi32>
+
// CHECK: linalg.generic
// CHECK: and
- %2 = "tosa.bitwise_and"(%arg0, %arg0) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
+ %4 = "tosa.bitwise_and"(%arg0, %arg0) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
// CHECK: linalg.generic
// CHECK: or
- %3 = "tosa.bitwise_or"(%arg0, %arg0) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
+ %5 = "tosa.bitwise_or"(%arg0, %arg0) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
// CHECK: linalg.generic
// CHECK: xor
- %4 = "tosa.bitwise_xor"(%arg0, %arg0) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
+ %6 = "tosa.bitwise_xor"(%arg0, %arg0) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
// CHECK: linalg.generic
// CHECK: shift_left
- %5 = "tosa.logical_left_shift"(%arg0, %arg0) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
+ %7 = "tosa.logical_left_shift"(%arg0, %arg0) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
// CHECK: linalg.generic
// CHECK: shift_right_unsigned
- %6 = "tosa.logical_right_shift"(%arg0, %arg0) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
+ %8 = "tosa.logical_right_shift"(%arg0, %arg0) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
// CHECK: linalg.generic
// CHECK: cmpi
- %7 = "tosa.greater"(%0, %1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1>
+ %9 = "tosa.greater"(%0, %1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1>
// CHECK: linalg.generic
// CHECK: cmpi
- %8 = "tosa.greater_equal"(%0, %1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1>
+ %10 = "tosa.greater_equal"(%0, %1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1>
+
+ // CHECK: linalg.generic
+ // CHECK: select
+ %11 = "tosa.select"(%9, %0, %1) : (tensor<1xi1>, tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
// CHECK: linalg.generic
// CHECK: cmpi
// CHECK: select
- %9 = "tosa.maximum"(%0, %1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
+ %12 = "tosa.maximum"(%0, %1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
// CHECK: linalg.generic
// CHECK: cmpi
// CHECK: select
- %10 = "tosa.minimum"(%0, %1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
+ %13 = "tosa.minimum"(%0, %1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
+ // CHECK: linalg.generic
+ // CHECK: cmpi
+ // CHECK: select
+ %14 = "tosa.clamp"(%0) {min_int = 1 : i64, max_int = 5 : i64, min_fp = 1.0 : f32, max_fp = 5.0 : f32} : (tensor<1xi32>) -> tensor<1xi32>
+
+ // CHECK: linalg.generic
+ // CHECK: cmpi
+ // CHECK: select
+ %15 = "tosa.reluN"(%0) {max_int = 5 : i64, max_fp = 5.0 : f32} : (tensor<1xi32>) -> tensor<1xi32>
return
}
More information about the Mlir-commits
mailing list