[Mlir-commits] [mlir] 21724dd - [MLIR][TOSA] Comparison based elementwise operations for tosa-to-linalg

Rob Suderman llvmlistbot at llvm.org
Mon Feb 1 21:44:40 PST 2021


Author: natashaknk
Date: 2021-02-01T21:37:52-08:00
New Revision: 21724ddcb7033cb010d57ff1a2d593cd70d462f5

URL: https://github.com/llvm/llvm-project/commit/21724ddcb7033cb010d57ff1a2d593cd70d462f5
DIFF: https://github.com/llvm/llvm-project/commit/21724ddcb7033cb010d57ff1a2d593cd70d462f5.diff

LOG: [MLIR][TOSA] Comparison based elementwise operations for tosa-to-linalg

Committed log, exp, maximum, minimum, comparison, ceil and floor conversions from TOSA to Linalg. Supports signless integer and floating-point types.

Reviewed By: rsuderman

Differential Revision: https://reviews.llvm.org/D95839

Added: 
    

Modified: 
    mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
    mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index 6ad5a256f98a..dd4bb2a3d016 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -29,7 +29,7 @@ createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args,
                                             PatternRewriter &rewriter) {
   Location loc = op->getLoc();
   auto elementTy =
-      op->getResult(0).getType().cast<ShapedType>().getElementType();
+      op->getOperand(0).getType().cast<ShapedType>().getElementType();
 
   // tosa::AbsOp
   if (isa<tosa::AbsOp>(op) && elementTy.isa<FloatType>())
@@ -66,6 +66,14 @@ createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args,
   if (isa<tosa::PowOp>(op) && elementTy.isa<FloatType>())
     return rewriter.create<mlir::PowFOp>(loc, resultTypes, args);
 
+  // tosa::LogOp
+  if (isa<tosa::LogOp>(op) && elementTy.isa<FloatType>())
+    return rewriter.create<mlir::LogOp>(loc, resultTypes, args);
+
+  // tosa::ExpOp
+  if (isa<tosa::ExpOp>(op) && elementTy.isa<FloatType>())
+    return rewriter.create<mlir::ExpOp>(loc, resultTypes, args);
+
   // tosa::SubOp
   if (isa<tosa::SubOp>(op) && elementTy.isa<FloatType>())
     return rewriter.create<mlir::SubFOp>(loc, resultTypes, args);
@@ -77,6 +85,58 @@ createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args,
   if (isa<tosa::TanhOp>(op) && elementTy.isa<FloatType>())
     return rewriter.create<mlir::TanhOp>(loc, resultTypes, args);
 
+  // tosa::GreaterOp
+  if (isa<tosa::GreaterOp>(op) && elementTy.isa<FloatType>())
+    return rewriter.create<mlir::CmpFOp>(loc, CmpFPredicate::OGT, args[0],
+                                         args[1]);
+
+  if (isa<tosa::GreaterOp>(op) && elementTy.isSignlessInteger())
+    return rewriter.create<mlir::CmpIOp>(loc, CmpIPredicate::sgt, args[0],
+                                         args[1]);
+
+  // tosa::GreaterEqualOp
+  if (isa<tosa::GreaterEqualOp>(op) && elementTy.isa<FloatType>())
+    return rewriter.create<mlir::CmpFOp>(loc, CmpFPredicate::OGE, args[0],
+                                         args[1]);
+
+  if (isa<tosa::GreaterEqualOp>(op) && elementTy.isSignlessInteger())
+    return rewriter.create<mlir::CmpIOp>(loc, CmpIPredicate::sge, args[0],
+                                         args[1]);
+
+  // tosa::MaximumOp
+  if (isa<tosa::MaximumOp>(op) && elementTy.isa<FloatType>()) {
+    auto predicate = rewriter.create<mlir::CmpFOp>(loc, CmpFPredicate::OGT,
+                                                   args[0], args[1]);
+    return rewriter.create<mlir::SelectOp>(loc, predicate, args[0], args[1]);
+  }
+
+  if (isa<tosa::MaximumOp>(op) && elementTy.isSignlessInteger()) {
+    auto predicate = rewriter.create<mlir::CmpIOp>(loc, CmpIPredicate::sgt,
+                                                   args[0], args[1]);
+    return rewriter.create<mlir::SelectOp>(loc, predicate, args[0], args[1]);
+  }
+
+  // tosa::MinimumOp
+  if (isa<tosa::MinimumOp>(op) && elementTy.isa<FloatType>()) {
+    auto predicate = rewriter.create<mlir::CmpFOp>(loc, CmpFPredicate::OLT,
+                                                   args[0], args[1]);
+    return rewriter.create<mlir::SelectOp>(loc, predicate, args[0], args[1]);
+  }
+
+  if (isa<tosa::MinimumOp>(op) && elementTy.isSignlessInteger()) {
+    auto predicate = rewriter.create<mlir::CmpIOp>(loc, CmpIPredicate::slt,
+                                                   args[0], args[1]);
+    return rewriter.create<mlir::SelectOp>(loc, predicate, args[0], args[1]);
+  }
+
+  // tosa::CeilOp
+  if (isa<tosa::CeilOp>(op) && elementTy.isa<FloatType>())
+    return rewriter.create<mlir::CeilFOp>(loc, resultTypes, args);
+
+  // tosa::FloorOp
+  if (isa<tosa::FloorOp>(op) && elementTy.isa<FloatType>())
+    return rewriter.create<mlir::FloorFOp>(loc, resultTypes, args);
+
   rewriter.notifyMatchFailure(
       op, "unhandled op for linalg body calculation for elementwise op");
   return nullptr;
@@ -94,19 +154,21 @@ elementwiseMatchAndRewriteHelper(Operation *operation,
 
   // For now require no broadcasting. Consider making it support broadcasting
   // operations.
-  Type uniqueTy = operation->getOperand(0).getType();
+  Type uniqueInTy = operation->getOperand(0).getType();
   bool allInputTypesEqual =
       llvm::all_of(operation->getOperandTypes(),
-                   [&](Type operandTy) { return operandTy == uniqueTy; });
+                   [&](Type operandTy) { return operandTy == uniqueInTy; });
   if (!allInputTypesEqual)
     return rewriter.notifyMatchFailure(operation,
                                        "All operands must have the same type");
-  bool allResultTypesEqual =
-      llvm::all_of(operation->getResultTypes(),
-                   [&](Type resultTy) { return resultTy == uniqueTy; });
-  if (!allResultTypesEqual)
+  bool resultAndInputShapeEqual =
+      llvm::all_of(operation->getResultTypes(), [&](Type resultTy) {
+        return resultTy.cast<ShapedType>().getShape() == t0.getShape();
+      });
+
+  if (!resultAndInputShapeEqual)
     return rewriter.notifyMatchFailure(
-        operation, "All results must have the same type as the input");
+        operation, "All results must have the same shape as the input");
 
   // Construct the indexing maps needed for linalg.generic ops.
   SmallVector<Type> bodyArgTypes;
@@ -179,10 +241,16 @@ void mlir::tosa::populateTosaToLinalgOnTensorsConversionPatterns(
     MLIRContext *context, OwningRewritePatternList *patterns) {
   patterns->insert<
       PointwiseConverter<tosa::AddOp>, PointwiseConverter<tosa::SubOp>,
-      PointwiseConverter<tosa::PowOp>, PointwiseConverter<tosa::AbsOp>,
+      PointwiseConverter<tosa::PowOp>, PointwiseConverter<tosa::LogOp>,
+      PointwiseConverter<tosa::ExpOp>, PointwiseConverter<tosa::AbsOp>,
       PointwiseConverter<tosa::TanhOp>, PointwiseConverter<tosa::BitwiseAndOp>,
       PointwiseConverter<tosa::BitwiseOrOp>,
       PointwiseConverter<tosa::BitwiseXorOp>,
       PointwiseConverter<tosa::LogicalLeftShiftOp>,
-      PointwiseConverter<tosa::LogicalRightShiftOp>>(context);
+      PointwiseConverter<tosa::LogicalRightShiftOp>,
+      PointwiseConverter<tosa::GreaterOp>,
+      PointwiseConverter<tosa::GreaterEqualOp>,
+      PointwiseConverter<tosa::MaximumOp>, PointwiseConverter<tosa::MinimumOp>,
+      PointwiseConverter<tosa::CeilOp>, PointwiseConverter<tosa::FloorOp>>(
+      context);
 }

diff  --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index 6f849f7030d3..e416246a19a4 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -100,6 +100,41 @@ func @test_simple_f32(%arg0: tensor<1xf32>) -> () {
   // CHECK: linalg.generic
   // CHECK: pow
   %4 = "tosa.pow"(%1, %2) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32>
+
+  // CHECK: linalg.generic
+  // CHECK: log
+  %5 = "tosa.log"(%arg0) : (tensor<1xf32>) -> tensor<1xf32>
+
+  // CHECK: linalg.generic
+  // CHECK: exp
+  %6 = "tosa.exp"(%arg0) : (tensor<1xf32>) -> tensor<1xf32>
+
+  // CHECK: linalg.generic
+  // CHECK: cmpf
+  %7 = "tosa.greater"(%0, %1) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xi1>
+
+  // CHECK: linalg.generic
+  // CHECK: cmpf
+  %8 = "tosa.greater_equal"(%0, %1) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xi1>
+
+  // CHECK: linalg.generic
+  // CHECK: cmpf
+  // CHECK: select
+  %9 = "tosa.maximum"(%0, %1) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32>
+
+  // CHECK: linalg.generic
+  // CHECK: cmpf
+  // CHECK: select
+  %10 = "tosa.minimum"(%0, %1) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32>
+
+  // CHECK: linalg.generic
+  // CHECK: ceil
+  %11 = "tosa.ceil"(%0) : (tensor<1xf32>) -> tensor<1xf32>
+
+  // CHECK: linalg.generic
+  // CHECK: floor
+  %12 = "tosa.floor"(%0) : (tensor<1xf32>) -> tensor<1xf32>
+
   return
 }
 
@@ -135,6 +170,25 @@ func @test_simple_i32(%arg0: tensor<1xi32>) -> () {
   // CHECK: shift_right_unsigned
   %6 = "tosa.logical_right_shift"(%arg0, %arg0) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
 
+  // CHECK: linalg.generic
+  // CHECK: cmpi
+  %7 = "tosa.greater"(%0, %1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1>
+
+  // CHECK: linalg.generic
+  // CHECK: cmpi
+  %8 = "tosa.greater_equal"(%0, %1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1>
+
+  // CHECK: linalg.generic
+  // CHECK: cmpi
+  // CHECK: select
+  %9 = "tosa.maximum"(%0, %1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
+
+  // CHECK: linalg.generic
+  // CHECK: cmpi
+  // CHECK: select
+  %10 = "tosa.minimum"(%0, %1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
+
+
   return
 }
 


        


More information about the Mlir-commits mailing list