[Mlir-commits] [mlir] cb3542e - [MLIR][TOSA] Added lowerings for Reduce operations to Linalg

Rob Suderman llvmlistbot at llvm.org
Mon Mar 8 11:03:48 PST 2021


Author: Rob Suderman
Date: 2021-03-08T10:57:19-08:00
New Revision: cb3542e1ca36d3deb49103336bbd409bc87b2177

URL: https://github.com/llvm/llvm-project/commit/cb3542e1ca36d3deb49103336bbd409bc87b2177
DIFF: https://github.com/llvm/llvm-project/commit/cb3542e1ca36d3deb49103336bbd409bc87b2177.diff

LOG: [MLIR][TOSA] Added lowerings for Reduce operations to Linalg

Lowerings for min, max, prod, and sum reduction operations on int and float
values. This includes reduction tests for both cases.

Reviewed By: mravishankar

Differential Revision: https://reviews.llvm.org/D97893

Added: 
    

Modified: 
    mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
    mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir

Removed: 
    


################################################################################
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index ece9380845b6..2fe4aa31e482 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -269,15 +269,15 @@ elementwiseMatchAndRewriteHelper(Operation *operation,
   SmallVector<Type> opResultTypes;
   SmallVector<Value> initTensors;
   for (auto result : results) {
-    auto resultType = result.getType().template cast<ShapedType>();
-    if (!resultType.hasStaticShape())
+    auto resultTy = result.getType().template cast<ShapedType>();
+    if (!resultTy.hasStaticShape())
       return rewriter.notifyMatchFailure(
           operation,
           "tosa to linalg conversion expects statically shaped tensors");
 
     initTensors.push_back(rewriter.create<linalg::InitTensorOp>(
-        loc, ArrayRef<Value>({}), resultType.getShape(),
-        resultType.getElementType()));
+        loc, ArrayRef<Value>({}), resultTy.getShape(),
+        resultTy.getElementType()));
     opResultTypes.push_back(result.getType());
   }
 
@@ -330,6 +330,152 @@ elementwiseMatchAndRewriteHelper(Operation *operation,
   return success();
 }
 
+// Returns the constant initial value for a given reduction operation. The
+// attribute type varies depending on the element type required.
+static Attribute createInitialValueForReduceOp(Operation *op, Type elementTy,
+                                               PatternRewriter &rewriter) {
+  if (isa<tosa::ReduceSumOp>(op) && elementTy.isa<FloatType>())
+    return rewriter.getFloatAttr(elementTy, 0.0);
+
+  if (isa<tosa::ReduceSumOp>(op) && elementTy.isa<IntegerType>())
+    return rewriter.getIntegerAttr(elementTy, 0);
+
+  if (isa<tosa::ReduceProdOp>(op) && elementTy.isa<FloatType>())
+    return rewriter.getFloatAttr(elementTy, 1.0);
+
+  if (isa<tosa::ReduceProdOp>(op) && elementTy.isa<IntegerType>())
+    return rewriter.getIntegerAttr(elementTy, 1);
+
+  if (isa<tosa::ReduceMinOp>(op) && elementTy.isa<FloatType>())
+    return rewriter.getFloatAttr(
+        elementTy, APFloat::getLargest(
+                       elementTy.cast<FloatType>().getFloatSemantics(), false));
+
+  if (isa<tosa::ReduceMinOp>(op) && elementTy.isa<IntegerType>())
+    return rewriter.getIntegerAttr(
+        elementTy, APInt::getSignedMaxValue(elementTy.getIntOrFloatBitWidth()));
+
+  if (isa<tosa::ReduceMaxOp>(op) && elementTy.isa<FloatType>())
+    return rewriter.getFloatAttr(
+        elementTy, APFloat::getLargest(
+                       elementTy.cast<FloatType>().getFloatSemantics(), true));
+
+  if (isa<tosa::ReduceMaxOp>(op) && elementTy.isa<IntegerType>())
+    return rewriter.getIntegerAttr(
+        elementTy, APInt::getSignedMinValue(elementTy.getIntOrFloatBitWidth()));
+
+  return {};
+}
+
+// Creates the body calculation for a reduction. The operations vary depending
+// on the input type.
+static Value createLinalgBodyCalculationForReduceOp(Operation *op,
+                                                    ValueRange args,
+                                                    Type elementTy,
+                                                    PatternRewriter &rewriter) {
+  Location loc = op->getLoc();
+  if (isa<tosa::ReduceSumOp>(op) && elementTy.isa<FloatType>()) {
+    return rewriter.create<AddFOp>(loc, args);
+  }
+
+  if (isa<tosa::ReduceSumOp>(op) && elementTy.isa<IntegerType>()) {
+    return rewriter.create<AddIOp>(loc, args);
+  }
+
+  if (isa<tosa::ReduceProdOp>(op) && elementTy.isa<FloatType>()) {
+    return rewriter.create<MulFOp>(loc, args);
+  }
+
+  if (isa<tosa::ReduceProdOp>(op) && elementTy.isa<IntegerType>()) {
+    return rewriter.create<MulIOp>(loc, args);
+  }
+
+  if (isa<tosa::ReduceMinOp>(op) && elementTy.isa<FloatType>()) {
+    auto predicate = rewriter.create<mlir::CmpFOp>(loc, CmpFPredicate::OLT,
+                                                   args[0], args[1]);
+    return rewriter.create<mlir::SelectOp>(loc, predicate, args[0], args[1]);
+  }
+
+  if (isa<tosa::ReduceMinOp>(op) && elementTy.isa<IntegerType>()) {
+    auto predicate = rewriter.create<mlir::CmpIOp>(loc, CmpIPredicate::slt,
+                                                   args[0], args[1]);
+    return rewriter.create<mlir::SelectOp>(loc, predicate, args[0], args[1]);
+  }
+
+  if (isa<tosa::ReduceMaxOp>(op) && elementTy.isa<FloatType>()) {
+    auto predicate = rewriter.create<mlir::CmpFOp>(loc, CmpFPredicate::OGT,
+                                                   args[0], args[1]);
+    return rewriter.create<mlir::SelectOp>(loc, predicate, args[0], args[1]);
+  }
+
+  if (isa<tosa::ReduceMaxOp>(op) && elementTy.isa<IntegerType>()) {
+    auto predicate = rewriter.create<mlir::CmpIOp>(loc, CmpIPredicate::sgt,
+                                                   args[0], args[1]);
+    return rewriter.create<mlir::SelectOp>(loc, predicate, args[0], args[1]);
+  }
+
+  return {};
+}
+
+// Performs the match and rewrite for reduction operations. This includes
+// declaring a correctly sized initial value, and the linalg.generic operation
+// that reduces across the specified axis.
+static LogicalResult reduceMatchAndRewriteHelper(Operation *op, uint64_t axis,
+                                                 PatternRewriter &rewriter) {
+  auto loc = op->getLoc();
+  auto inputTy = op->getOperand(0).getType().template cast<ShapedType>();
+  auto resultTy = op->getResult(0).getType().template cast<ShapedType>();
+  auto elementTy = resultTy.getElementType();
+  Value input = op->getOperand(0);
+
+  // First fill the output buffer with the init value.
+  auto initTensor = rewriter
+                        .create<linalg::InitTensorOp>(loc, ArrayRef<Value>({}),
+                                                      resultTy.getShape(),
+                                                      resultTy.getElementType())
+                        .result();
+
+  auto fillValueAttr = createInitialValueForReduceOp(op, elementTy, rewriter);
+  if (!fillValueAttr)
+    return rewriter.notifyMatchFailure(
+        op, "No initial value found for reduction operation");
+
+  auto fillValue = rewriter.create<ConstantOp>(loc, fillValueAttr);
+  auto filledTensor =
+      rewriter.create<linalg::FillOp>(loc, initTensor, fillValue).result();
+
+  SmallVector<AffineExpr, 2> srcExprs;
+  SmallVector<AffineExpr, 2> dstExprs;
+  SmallVector<StringRef, 4> iteratorTypes;
+  for (unsigned int i = 0, rank = inputTy.getRank(); i != rank; ++i) {
+    srcExprs.push_back(mlir::getAffineDimExpr(i, rewriter.getContext()));
+
+    iteratorTypes.push_back(axis == i ? getReductionIteratorTypeName()
+                                      : getParallelIteratorTypeName());
+    if (axis != i)
+      dstExprs.push_back(mlir::getAffineDimExpr(i, rewriter.getContext()));
+  }
+
+  bool didEncounterError = false;
+  auto maps = AffineMap::inferFromExprList({srcExprs, dstExprs});
+  auto linalgOp = rewriter.create<linalg::GenericOp>(
+      loc, resultTy, input, filledTensor, maps, iteratorTypes,
+      [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange blockArgs) {
+        auto result = createLinalgBodyCalculationForReduceOp(
+            op, blockArgs, elementTy, rewriter);
+        if (result)
+          didEncounterError = true;
+
+        nestedBuilder.create<linalg::YieldOp>(loc, result);
+      });
+
+  if (!didEncounterError)
+    return failure();
+
+  rewriter.replaceOp(op, linalgOp.getOperation()->getResults());
+  return success();
+}
+
 namespace {
 
 template <typename SrcOp>
@@ -500,6 +646,17 @@ class IdentityNConverter : public OpRewritePattern<SrcOp> {
   }
 };
 
+template <typename SrcOp>
+class ReduceConverter : public OpRewritePattern<SrcOp> {
+public:
+  using OpRewritePattern<SrcOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(SrcOp reduceOp,
+                                PatternRewriter &rewriter) const final {
+    return reduceMatchAndRewriteHelper(reduceOp, reduceOp.axis(), rewriter);
+  }
+};
+
 } // namespace
 
 void mlir::tosa::populateTosaToLinalgOnTensorsConversionPatterns(
@@ -521,6 +678,8 @@ void mlir::tosa::populateTosaToLinalgOnTensorsConversionPatterns(
       PointwiseConverter<tosa::CeilOp>, PointwiseConverter<tosa::FloorOp>,
       PointwiseConverter<tosa::ClampOp>, PointwiseConverter<tosa::ReluNOp>,
       IdentityNConverter<tosa::IdentityOp>,
-      IdentityNConverter<tosa::IdentityNOp>,
-      ReshapeOpConverter, TransposeConverter>(context);
+      IdentityNConverter<tosa::IdentityNOp>, ReduceConverter<tosa::ReduceMinOp>,
+      ReduceConverter<tosa::ReduceMaxOp>, ReduceConverter<tosa::ReduceSumOp>,
+      ReduceConverter<tosa::ReduceProdOp>, ReshapeOpConverter,
+      TransposeConverter>(context);
 }

diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index a39c722c177f..d1868e7683ce 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -335,3 +335,101 @@ func @test_transpose(%arg0: tensor<1x2x3xi32>) -> () {
   %1 = "tosa.transpose"(%arg0, %0) : (tensor<1x2x3xi32>, tensor<3xi32>) -> (tensor<2x3x1xi32>)
   return
 }
+
+// -----
+
+// CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1) -> (d1)>
+// CHECK: #[[$MAP2:.*]] = affine_map<(d0, d1) -> (d0)>
+
+// CHECK-LABEL: @reduce_float
+// CHECK-SAME: [[ARG0:%.+]]: tensor<5x4xf32>
+func @reduce_float(%arg0: tensor<5x4xf32>) -> () {
+  // CHECK: [[INIT:%.+]] = linalg.init_tensor [4]
+  // CHECK: [[CST0:%.+]] = constant 0.0
+  // CHECK: [[FILL:%.+]] = linalg.fill([[INIT]], [[CST0]])
+  // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["reduction", "parallel"]} ins([[ARG0]] : tensor<5x4xf32>) outs([[FILL]] : tensor<4xf32>)
+  // CHECK: ^bb0(%arg1: f32, %arg2: f32)
+  // CHECK:   [[RES:%.+]] = addf %arg1, %arg2 : f32
+  // CHECK:   linalg.yield [[RES]] : f32
+  %0 = "tosa.reduce_sum"(%arg0) {axis = 0 : i64} : (tensor<5x4xf32>) -> tensor<4xf32>
+
+  // CHECK: [[INIT:%.+]] = linalg.init_tensor [5]
+  // CHECK: [[CST0:%.+]] = constant 0.0
+  // CHECK: [[FILL:%.+]] = linalg.fill([[INIT]], [[CST0]])
+  // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP2]]], iterator_types = ["parallel", "reduction"]} ins([[ARG0]] : tensor<5x4xf32>) outs([[FILL]] : tensor<5xf32>)
+  // CHECK: ^bb0(%arg1: f32, %arg2: f32)
+  // CHECK:   [[RES:%.+]] = addf %arg1, %arg2 : f32
+  // CHECK:   linalg.yield [[RES]] : f32
+  %1 = "tosa.reduce_sum"(%arg0) {axis = 1 : i64} : (tensor<5x4xf32>) -> tensor<5xf32>
+
+  // CHECK: constant 1.0
+  // CHECK: linalg.fill
+  // CHECK: linalg.generic
+  // CHECK: mulf
+  %2 = "tosa.reduce_prod"(%arg0) {axis = 0 : i64} : (tensor<5x4xf32>) -> tensor<4xf32>
+
+  // CHECK: constant 3.40282347E+38 : f32
+  // CHECK: linalg.fill
+  // CHECK: linalg.generic
+  // CHECK: cmpf olt
+  // CHECK: select
+  %3 = "tosa.reduce_min"(%arg0) {axis = 0 : i64} : (tensor<5x4xf32>) -> tensor<4xf32>
+
+  // CHECK: constant -3.40282347E+38 : f32
+  // CHECK: linalg.fill
+  // CHECK: linalg.generic
+  // CHECK: cmpf ogt
+  // CHECK: select
+  %4 = "tosa.reduce_max"(%arg0) {axis = 0 : i64} : (tensor<5x4xf32>) -> tensor<4xf32>
+  return
+}
+
+// -----
+
+// CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1) -> (d1)>
+// CHECK: #[[$MAP2:.*]] = affine_map<(d0, d1) -> (d0)>
+
+// CHECK-LABEL: @reduce_int
+// CHECK-SAME: [[ARG0:%.+]]: tensor<5x4xi32>
+func @reduce_int(%arg0: tensor<5x4xi32>) -> () {
+  // CHECK: [[INIT:%.+]] = linalg.init_tensor [4]
+  // CHECK: [[CST0:%.+]] = constant 0
+  // CHECK: [[FILL:%.+]] = linalg.fill([[INIT]], [[CST0]])
+  // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["reduction", "parallel"]} ins([[ARG0]] : tensor<5x4xi32>) outs([[FILL]] : tensor<4xi32>)
+  // CHECK: ^bb0(%arg1: i32, %arg2: i32)
+  // CHECK:   [[RES:%.+]] = addi %arg1, %arg2 : i32
+  // CHECK:   linalg.yield [[RES]] : i32
+  %0 = "tosa.reduce_sum"(%arg0) {axis = 0 : i64} : (tensor<5x4xi32>) -> tensor<4xi32>
+
+  // CHECK: [[INIT:%.+]] = linalg.init_tensor [5]
+  // CHECK: [[CST0:%.+]] = constant 0
+  // CHECK: [[FILL:%.+]] = linalg.fill([[INIT]], [[CST0]])
+  // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP2]]], iterator_types = ["parallel", "reduction"]} ins([[ARG0]] : tensor<5x4xi32>) outs([[FILL]] : tensor<5xi32>)
+  // CHECK: ^bb0(%arg1: i32, %arg2: i32)
+  // CHECK:   [[RES:%.+]] = addi %arg1, %arg2 : i32
+  // CHECK:   linalg.yield [[RES]] : i32
+  %1 = "tosa.reduce_sum"(%arg0) {axis = 1 : i64} : (tensor<5x4xi32>) -> tensor<5xi32>
+
+  // CHECK: constant 1
+  // CHECK: linalg.fill
+  // CHECK: linalg.generic
+  // CHECK: muli
+  %2 = "tosa.reduce_prod"(%arg0) {axis = 0 : i64} : (tensor<5x4xi32>) -> tensor<4xi32>
+
+  // CHECK: constant 2147483647 : i32
+  // CHECK: linalg.fill
+  // CHECK: linalg.generic
+  // CHECK: cmpi slt
+  // CHECK: select
+  %3 = "tosa.reduce_min"(%arg0) {axis = 0 : i64} : (tensor<5x4xi32>) -> tensor<4xi32>
+
+  // CHECK: constant -2147483648 : i32
+  // CHECK: linalg.fill
+  // CHECK: linalg.generic
+  // CHECK: cmpi sgt
+  // CHECK: select
+  %4 = "tosa.reduce_max"(%arg0) {axis = 0 : i64} : (tensor<5x4xi32>) -> tensor<4xi32>
+  return
+}


        


More information about the Mlir-commits mailing list