[Mlir-commits] [mlir] 25b4a6a - [MLIR][TOSA] Add lowering from TOSA to Linalg for math-based and elementwise ops

Rob Suderman llvmlistbot at llvm.org
Thu Feb 18 12:17:15 PST 2021


Author: natashaknk
Date: 2021-02-18T12:10:10-08:00
New Revision: 25b4a6a7f038184ba77dd3c0d8605da454bb4a06

URL: https://github.com/llvm/llvm-project/commit/25b4a6a7f038184ba77dd3c0d8605da454bb4a06
DIFF: https://github.com/llvm/llvm-project/commit/25b4a6a7f038184ba77dd3c0d8605da454bb4a06.diff

LOG: [MLIR][TOSA] Add lowering from TOSA to Linalg for math-based and elementwise ops

This patch adds lowering to Linalg for the following TOSA ops: negate, rsqrt, mul, select, clamp, and reluN. It includes support for signless-integer and floating-point types.

Reviewed By: rsuderman

Differential Revision: https://reviews.llvm.org/D96924

Added: 
    

Modified: 
    mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
    mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index 63c99f648392..8e096e48d2d3 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -24,6 +24,28 @@ static SmallVector<StringRef> getNParallelLoopsAttrs(unsigned nParallelLoops) {
   return SmallVector<StringRef>(nParallelLoops, getParallelIteratorTypeName());
 }
 
+/// Materializes the integer attribute `attrName` of `op` as a constant op of
+/// type `requiredAttrType`, created at `op`'s location.
+///
+/// The attribute's value is read sign-extended to int64_t and then narrowed
+/// through `T` (the clamp/reluN lowerings instantiate this with int32_t), so
+/// values outside T's range are truncated before the constant is built.
+/// The attribute must exist and be an IntegerAttr: `cast<>` asserts (in
+/// assertion-enabled builds) otherwise.
+///
+/// `attrName` is taken by const reference to avoid copying the name on
+/// every call.
+template <typename T>
+static mlir::ConstantOp
+createConstFromIntAttribute(Operation *op, const std::string &attrName,
+                            Type requiredAttrType, PatternRewriter &rewriter) {
+  auto castedN = static_cast<T>(
+      op->getAttr(attrName).cast<IntegerAttr>().getValue().getSExtValue());
+  return rewriter.create<mlir::ConstantOp>(
+      op->getLoc(), IntegerAttr::get(requiredAttrType, castedN));
+}
+
+/// Emits IR that clamps `args[0]` into the range [`min`, `max`].
+///
+/// `T` is the comparison op to build (CmpFOp or CmpIOp in this file) and
+/// `pred` is its predicate. Callers pass a strict less-than predicate
+/// (CmpFPredicate::OLT / CmpIPredicate::slt). Two compare+select pairs are
+/// created: the first replaces values below `min`, the second replaces values
+/// above `max`. The final select is returned.
+template <typename T, typename P>
+static mlir::SelectOp clampHelper(Operation *op, ValueRange args,
+                                  mlir::ConstantOp min, mlir::ConstantOp max,
+                                  P pred, PatternRewriter &rewriter) {
+  auto input = args[0];
+  Location location = op->getLoc();
+
+  // Lower bound: choose `min` when input < min, otherwise keep the input.
+  auto belowLower = rewriter.create<T>(location, pred, input, min);
+  auto lowerClamped =
+      rewriter.create<mlir::SelectOp>(location, belowLower, min, input);
+
+  // Upper bound: the comparison is against the original input (not the
+  // lower-clamped value). Choose `max` when max < input.
+  auto aboveUpper = rewriter.create<T>(location, pred, max, input);
+  return rewriter.create<mlir::SelectOp>(location, aboveUpper, max,
+                                         lowerClamped);
+}
+
 static Value
 createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args,
                                             ArrayRef<Type> resultTypes,
@@ -43,6 +65,42 @@ createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args,
   if (isa<tosa::AddOp>(op) && elementTy.isa<IntegerType>())
     return rewriter.create<mlir::AddIOp>(loc, resultTypes, args);
 
+  // tosa::SubOp
+  if (isa<tosa::SubOp>(op) && elementTy.isa<FloatType>())
+    return rewriter.create<mlir::SubFOp>(loc, resultTypes, args);
+
+  if (isa<tosa::SubOp>(op) && elementTy.isa<IntegerType>())
+    return rewriter.create<mlir::SubIOp>(loc, resultTypes, args);
+
+  // tosa::MulOp
+  if (isa<tosa::MulOp>(op) && elementTy.isa<FloatType>()) {
+    if (dyn_cast<tosa::MulOp>(op).shift() != 0) {
+      (void)rewriter.notifyMatchFailure(op,
+                                        "Cannot have shift value for float");
+      return nullptr;
+    }
+    return rewriter.create<mlir::MulFOp>(loc, resultTypes, args);
+  }
+
+  if (isa<tosa::MulOp>(op) && elementTy.isa<IntegerType>()) {
+    auto mul =
+        rewriter.create<mlir::MulIOp>(loc, resultTypes, args[0], args[1]);
+    auto constant =
+        rewriter.create<mlir::ConstantOp>(loc, elementTy, op->getAttr("shift"));
+    return rewriter.create<mlir::SignedShiftRightOp>(loc, resultTypes, mul,
+                                                     constant);
+  }
+
+  // tosa::NegateOp
+  if (isa<tosa::NegateOp>(op) && elementTy.isa<IntegerType>()) {
+    auto constant =
+        rewriter.create<mlir::ConstantOp>(loc, IntegerAttr::get(elementTy, -1));
+    return rewriter.create<mlir::MulIOp>(loc, resultTypes, args[0], constant);
+  }
+
+  if (isa<tosa::NegateOp>(op) && elementTy.isa<FloatType>())
+    return rewriter.create<mlir::NegFOp>(loc, resultTypes, args);
+
   // tosa::BitwiseAndOp
   if (isa<tosa::BitwiseAndOp>(op) && elementTy.isa<IntegerType>())
     return rewriter.create<mlir::AndOp>(loc, resultTypes, args);
@@ -67,6 +125,10 @@ createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args,
   if (isa<tosa::PowOp>(op) && elementTy.isa<FloatType>())
     return rewriter.create<mlir::math::PowFOp>(loc, resultTypes, args);
 
+  // tosa::RsqrtOp
+  if (isa<tosa::RsqrtOp>(op) && elementTy.isa<FloatType>())
+    return rewriter.create<mlir::math::RsqrtOp>(loc, resultTypes, args);
+
   // tosa::LogOp
   if (isa<tosa::LogOp>(op) && elementTy.isa<FloatType>())
     return rewriter.create<mlir::math::LogOp>(loc, resultTypes, args);
@@ -75,13 +137,6 @@ createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args,
   if (isa<tosa::ExpOp>(op) && elementTy.isa<FloatType>())
     return rewriter.create<mlir::math::ExpOp>(loc, resultTypes, args);
 
-  // tosa::SubOp
-  if (isa<tosa::SubOp>(op) && elementTy.isa<FloatType>())
-    return rewriter.create<mlir::SubFOp>(loc, resultTypes, args);
-
-  if (isa<tosa::SubOp>(op) && elementTy.isa<IntegerType>())
-    return rewriter.create<mlir::SubIOp>(loc, resultTypes, args);
-
   // tosa::TanhOp
   if (isa<tosa::TanhOp>(op) && elementTy.isa<FloatType>())
     return rewriter.create<mlir::math::TanhOp>(loc, resultTypes, args);
@@ -104,6 +159,13 @@ createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args,
     return rewriter.create<mlir::CmpIOp>(loc, CmpIPredicate::sge, args[0],
                                          args[1]);
 
+  // tosa::SelectOp
+  if (isa<tosa::SelectOp>(op)) {
+    elementTy = op->getOperand(1).getType().cast<ShapedType>().getElementType();
+    if (elementTy.isa<FloatType>() || elementTy.isa<IntegerType>())
+      return rewriter.create<mlir::SelectOp>(loc, args[0], args[1], args[2]);
+  }
+
   // tosa::MaximumOp
   if (isa<tosa::MaximumOp>(op) && elementTy.isa<FloatType>()) {
     auto predicate = rewriter.create<mlir::CmpFOp>(loc, CmpFPredicate::OGT,
@@ -138,6 +200,44 @@ createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args,
   if (isa<tosa::FloorOp>(op) && elementTy.isa<FloatType>())
     return rewriter.create<mlir::FloorFOp>(loc, resultTypes, args);
 
+  // tosa::ClampOp
+  if (isa<tosa::ClampOp>(op) && elementTy.isa<FloatType>()) {
+    auto min = rewriter.create<mlir::ConstantOp>(loc, elementTy,
+                                                 op->getAttr("min_fp"));
+    auto max = rewriter.create<mlir::ConstantOp>(loc, elementTy,
+                                                 op->getAttr("max_fp"));
+    return clampHelper<mlir::CmpFOp>(op, args, min, max, CmpFPredicate::OLT,
+                                     rewriter);
+  }
+
+  if (isa<tosa::ClampOp>(op) && elementTy.isa<IntegerType>()) {
+    auto min = createConstFromIntAttribute<int32_t>(op, "min_int", elementTy,
+                                                    rewriter);
+    auto max = createConstFromIntAttribute<int32_t>(op, "max_int", elementTy,
+                                                    rewriter);
+    return clampHelper<mlir::CmpIOp>(op, args, min, max, CmpIPredicate::slt,
+                                     rewriter);
+  }
+
+  // tosa::ReluNOp
+  if (isa<tosa::ReluNOp>(op) && elementTy.isa<FloatType>()) {
+    auto zero =
+        rewriter.create<mlir::ConstantOp>(loc, FloatAttr::get(elementTy, 0));
+    auto n = rewriter.create<mlir::ConstantOp>(loc, elementTy,
+                                               op->getAttr("max_fp"));
+    return clampHelper<mlir::CmpFOp>(op, args, zero, n, CmpFPredicate::OLT,
+                                     rewriter);
+  }
+
+  if (isa<tosa::ReluNOp>(op) && elementTy.isa<IntegerType>()) {
+    auto zero =
+        rewriter.create<mlir::ConstantOp>(loc, IntegerAttr::get(elementTy, 0));
+    auto n = createConstFromIntAttribute<int32_t>(op, "max_int", elementTy,
+                                                  rewriter);
+    return clampHelper<mlir::CmpIOp>(op, args, zero, n, CmpIPredicate::slt,
+                                     rewriter);
+  }
+
   (void)rewriter.notifyMatchFailure(
       op, "unhandled op for linalg body calculation for elementwise op");
   return nullptr;
@@ -245,16 +345,19 @@ void mlir::tosa::populateTosaToLinalgOnTensorsConversionPatterns(
     MLIRContext *context, OwningRewritePatternList *patterns) {
   patterns->insert<
       PointwiseConverter<tosa::AddOp>, PointwiseConverter<tosa::SubOp>,
-      PointwiseConverter<tosa::PowOp>, PointwiseConverter<tosa::LogOp>,
-      PointwiseConverter<tosa::ExpOp>, PointwiseConverter<tosa::AbsOp>,
-      PointwiseConverter<tosa::TanhOp>, PointwiseConverter<tosa::BitwiseAndOp>,
+      PointwiseConverter<tosa::MulOp>, PointwiseConverter<tosa::NegateOp>,
+      PointwiseConverter<tosa::PowOp>, PointwiseConverter<tosa::RsqrtOp>,
+      PointwiseConverter<tosa::LogOp>, PointwiseConverter<tosa::ExpOp>,
+      PointwiseConverter<tosa::AbsOp>, PointwiseConverter<tosa::TanhOp>,
+      PointwiseConverter<tosa::BitwiseAndOp>,
       PointwiseConverter<tosa::BitwiseOrOp>,
       PointwiseConverter<tosa::BitwiseXorOp>,
       PointwiseConverter<tosa::LogicalLeftShiftOp>,
       PointwiseConverter<tosa::LogicalRightShiftOp>,
-      PointwiseConverter<tosa::GreaterOp>,
+      PointwiseConverter<tosa::SelectOp>, PointwiseConverter<tosa::GreaterOp>,
       PointwiseConverter<tosa::GreaterEqualOp>,
       PointwiseConverter<tosa::MaximumOp>, PointwiseConverter<tosa::MinimumOp>,
-      PointwiseConverter<tosa::CeilOp>, PointwiseConverter<tosa::FloorOp>>(
+      PointwiseConverter<tosa::CeilOp>, PointwiseConverter<tosa::FloorOp>,
+      PointwiseConverter<tosa::ClampOp>, PointwiseConverter<tosa::ReluNOp>>(
       context);
 }

diff  --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index 8963544838e1..022421459d16 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -116,43 +116,69 @@ func @test_simple_f32(%arg0: tensor<1xf32>) -> () {
   // CHECK: subf
   %3 = "tosa.sub"(%0, %1) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32>
 
+  // CHECK: linalg.generic
+  // CHECK: mulf
+  %4 = "tosa.mul"(%0, %1) {shift = 0 : i32} : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32>
+
+  // CHECK: linalg.generic
+  // CHECK: negf
+  %5 = "tosa.negate"(%0) : (tensor<1xf32>) -> tensor<1xf32>
+
   // CHECK: linalg.generic
   // CHECK: pow
-  %4 = "tosa.pow"(%1, %2) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32>
+  %6 = "tosa.pow"(%1, %2) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32>
+
+  // CHECK: linalg.generic
+  // CHECK: rsqrt
+  %7 = "tosa.rsqrt"(%1) : (tensor<1xf32>) -> tensor<1xf32>
 
   // CHECK: linalg.generic
   // CHECK: log
-  %5 = "tosa.log"(%arg0) : (tensor<1xf32>) -> tensor<1xf32>
+  %8 = "tosa.log"(%arg0) : (tensor<1xf32>) -> tensor<1xf32>
 
   // CHECK: linalg.generic
   // CHECK: exp
-  %6 = "tosa.exp"(%arg0) : (tensor<1xf32>) -> tensor<1xf32>
+  %9 = "tosa.exp"(%arg0) : (tensor<1xf32>) -> tensor<1xf32>
 
   // CHECK: linalg.generic
   // CHECK: cmpf
-  %7 = "tosa.greater"(%0, %1) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xi1>
+  %10 = "tosa.greater"(%0, %1) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xi1>
 
   // CHECK: linalg.generic
   // CHECK: cmpf
-  %8 = "tosa.greater_equal"(%0, %1) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xi1>
+  %11 = "tosa.greater_equal"(%0, %1) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xi1>
+
+  // CHECK: linalg.generic
+  // CHECK: select
+  %12 = "tosa.select"(%10, %0, %1) : (tensor<1xi1>, tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32>
 
   // CHECK: linalg.generic
   // CHECK: cmpf
   // CHECK: select
-  %9 = "tosa.maximum"(%0, %1) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32>
+  %13 = "tosa.maximum"(%0, %1) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32>
 
   // CHECK: linalg.generic
   // CHECK: cmpf
   // CHECK: select
-  %10 = "tosa.minimum"(%0, %1) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32>
+  %14 = "tosa.minimum"(%0, %1) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32>
 
   // CHECK: linalg.generic
   // CHECK: ceil
-  %11 = "tosa.ceil"(%0) : (tensor<1xf32>) -> tensor<1xf32>
+  %15 = "tosa.ceil"(%0) : (tensor<1xf32>) -> tensor<1xf32>
 
   // CHECK: linalg.generic
   // CHECK: floor
-  %12 = "tosa.floor"(%0) : (tensor<1xf32>) -> tensor<1xf32>
+  %16 = "tosa.floor"(%0) : (tensor<1xf32>) -> tensor<1xf32>
+
+  // CHECK: linalg.generic
+  // CHECK: cmpf
+  // CHECK: select
+  %17 = "tosa.clamp"(%0) {min_int = 1 : i64, max_int = 5 : i64, min_fp = 1.0 : f32, max_fp = 5.0 : f32} : (tensor<1xf32>) -> tensor<1xf32>
+
+  // CHECK: linalg.generic
+  // CHECK: cmpf
+  // CHECK: select
+  %18 = "tosa.reluN"(%0) {max_int = 5 : i64, max_fp = 5.0 : f32} : (tensor<1xf32>) -> tensor<1xf32>
 
   return
 }
@@ -169,44 +195,65 @@ func @test_simple_i32(%arg0: tensor<1xi32>) -> () {
   // CHECK: subi
   %1 = "tosa.sub"(%arg0, %arg0) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
 
+  // CHECK: linalg.generic
+  // CHECK: muli
+  %2 = "tosa.mul"(%arg0, %arg0) {shift = 0 : i32} : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
+
+  // CHECK: linalg.generic
+  // CHECK: muli
+  %3 = "tosa.negate"(%arg0) : (tensor<1xi32>) -> tensor<1xi32>
+
   // CHECK: linalg.generic
   // CHECK: and
-  %2 = "tosa.bitwise_and"(%arg0, %arg0) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
+  %4 = "tosa.bitwise_and"(%arg0, %arg0) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
 
   // CHECK: linalg.generic
   // CHECK: or
-  %3 = "tosa.bitwise_or"(%arg0, %arg0) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
+  %5 = "tosa.bitwise_or"(%arg0, %arg0) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
 
   // CHECK: linalg.generic
   // CHECK: xor
-  %4 = "tosa.bitwise_xor"(%arg0, %arg0) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
+  %6 = "tosa.bitwise_xor"(%arg0, %arg0) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
 
   // CHECK: linalg.generic
   // CHECK: shift_left
-  %5 = "tosa.logical_left_shift"(%arg0, %arg0) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
+  %7 = "tosa.logical_left_shift"(%arg0, %arg0) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
 
   // CHECK: linalg.generic
   // CHECK: shift_right_unsigned
-  %6 = "tosa.logical_right_shift"(%arg0, %arg0) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
+  %8 = "tosa.logical_right_shift"(%arg0, %arg0) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
 
   // CHECK: linalg.generic
   // CHECK: cmpi
-  %7 = "tosa.greater"(%0, %1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1>
+  %9 = "tosa.greater"(%0, %1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1>
 
   // CHECK: linalg.generic
   // CHECK: cmpi
-  %8 = "tosa.greater_equal"(%0, %1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1>
+  %10 = "tosa.greater_equal"(%0, %1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1>
+
+  // CHECK: linalg.generic
+  // CHECK: select
+  %11 = "tosa.select"(%9, %0, %1) : (tensor<1xi1>, tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
 
   // CHECK: linalg.generic
   // CHECK: cmpi
   // CHECK: select
-  %9 = "tosa.maximum"(%0, %1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
+  %12 = "tosa.maximum"(%0, %1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
 
   // CHECK: linalg.generic
   // CHECK: cmpi
   // CHECK: select
-  %10 = "tosa.minimum"(%0, %1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
+  %13 = "tosa.minimum"(%0, %1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
 
+  // CHECK: linalg.generic
+  // CHECK: cmpi
+  // CHECK: select
+  %14 = "tosa.clamp"(%0) {min_int = 1 : i64, max_int = 5 : i64, min_fp = 1.0 : f32, max_fp = 5.0 : f32} : (tensor<1xi32>) -> tensor<1xi32>
+
+  // CHECK: linalg.generic
+  // CHECK: cmpi
+  // CHECK: select
+  %15 = "tosa.reluN"(%0) {max_int = 5 : i64, max_fp = 5.0 : f32} : (tensor<1xi32>) -> tensor<1xi32>
 
   return
 }


        


More information about the Mlir-commits mailing list