[Mlir-commits] [mlir] e20911b - [mlir][tosa] Add tosa.matmul and tosa.fully_connected lowering

Rob Suderman llvmlistbot at llvm.org
Tue Mar 23 13:11:23 PDT 2021


Author: natashaknk
Date: 2021-03-23T13:09:53-07:00
New Revision: e20911b5c0360882ea166886af75c0038310f6e5

URL: https://github.com/llvm/llvm-project/commit/e20911b5c0360882ea166886af75c0038310f6e5
DIFF: https://github.com/llvm/llvm-project/commit/e20911b5c0360882ea166886af75c0038310f6e5.diff

LOG: [mlir][tosa] Add tosa.matmul and tosa.fully_connected lowering

Adds lowerings for tosa.matmul and tosa.fully_connected to Linalg on tensors. Only 2D tensors are supported for the inputs and weights, and 1D tensors for the bias.
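
For illustration, a rough before/after sketch of the new tosa.matmul lowering
(it mirrors the FileCheck expectations added to the test file below; the
function and value names are made up for the example, and the op syntax
follows the Linalg-on-tensors spelling at this revision):

    // Input: a 2D tosa.matmul.
    func @matmul_example(%a: tensor<5x3xf32>, %b: tensor<3x6xf32>) -> tensor<5x6xf32> {
      %0 = "tosa.matmul"(%a, %b) : (tensor<5x3xf32>, tensor<3x6xf32>) -> tensor<5x6xf32>
      return %0 : tensor<5x6xf32>
    }

    // After the TOSA-to-Linalg-on-tensors conversion (approximately): a
    // zero-filled init tensor feeds a linalg.matmul.
    %zero = constant 0.000000e+00 : f32
    %init = linalg.init_tensor [5, 6] : tensor<5x6xf32>
    %fill = linalg.fill(%init, %zero) : tensor<5x6xf32>, f32 -> tensor<5x6xf32>
    %result = linalg.matmul ins(%a, %b : tensor<5x3xf32>, tensor<3x6xf32>)
                            outs(%fill : tensor<5x6xf32>) -> tensor<5x6xf32>

The tosa.fully_connected lowering is analogous, except the 1D bias is first
reshaped to [1, n] and broadcast into the init tensor by a linalg.generic
before the matmul accumulates into it (see the second test case below).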

Reviewed By: rsuderman

Differential Revision: https://reviews.llvm.org/D99211

Added: 
    

Modified: 
    mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
    mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index e0117e0f694fa..fe1336fb47319 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -612,6 +612,84 @@ class PointwiseConverter : public OpRewritePattern<SrcOp> {
   }
 };
 
+class MatMulConverter : public OpConversionPattern<tosa::MatMulOp> {
+public:
+  using OpConversionPattern<tosa::MatMulOp>::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(tosa::MatMulOp op, ArrayRef<Value> args,
+                  ConversionPatternRewriter &rewriter) const final {
+    tosa::MatMulOp::Adaptor adaptor(args);
+
+    Location loc = op.getLoc();
+
+    auto outputTy = op.getType().cast<ShapedType>();
+    auto outputElementTy = outputTy.getElementType();
+    auto zero_attr = rewriter.getZeroAttr(outputElementTy);
+    Value zero = rewriter.create<ConstantOp>(loc, zero_attr);
+    auto initTensor = rewriter.create<linalg::InitTensorOp>(
+        loc, outputTy.getShape(), outputTy.getElementType());
+    Value zeroTensor =
+        rewriter.create<linalg::FillOp>(loc, initTensor, zero).getResult(0);
+    rewriter.replaceOpWithNewOp<linalg::MatmulOp>(
+        op, TypeRange{op.getType()}, ValueRange{adaptor.a(), adaptor.b()},
+        ValueRange{zeroTensor});
+    return success();
+  }
+};
+
+class FullyConnectedConverter
+    : public OpConversionPattern<tosa::FullyConnectedOp> {
+public:
+  using OpConversionPattern<tosa::FullyConnectedOp>::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(tosa::FullyConnectedOp op, ArrayRef<Value> args,
+                  ConversionPatternRewriter &rewriter) const final {
+    tosa::FullyConnectedOp::Adaptor adaptor(args);
+
+    Location loc = op.getLoc();
+    auto outputTy = op.getType().cast<ShapedType>();
+    auto biasTy = op->getOperand(2).getType().cast<ShapedType>();
+
+    // Reshaping the bias from n to [1, n] for broadcasting
+    SmallVector<int64_t> biasShapeReshaped;
+    biasShapeReshaped.push_back(1);
+    biasShapeReshaped.push_back(biasTy.getShape()[0]);
+
+    RankedTensorType reshapedBias =
+        RankedTensorType::get(biasShapeReshaped, outputTy.getElementType());
+    auto reshapeResult =
+        rewriter.create<tosa::ReshapeOp>(loc, reshapedBias, args[2])
+            ->getResult(0);
+
+    // Creating maps for the output of MatMul and the bias
+    SmallVector<AffineMap, 4> indexingMaps;
+    indexingMaps.push_back(createAffineMapForType(reshapedBias, rewriter));
+    indexingMaps.push_back(rewriter.getMultiDimIdentityMap(outputTy.getRank()));
+
+    auto initTensor =
+        rewriter
+            .create<linalg::InitTensorOp>(loc, outputTy.getShape(),
+                                          outputTy.getElementType())
+            ->getResults();
+
+    auto linalgOp =
+        rewriter
+            .create<linalg::GenericOp>(
+                loc, outputTy, reshapeResult, initTensor, indexingMaps,
+                getNParallelLoopsAttrs(outputTy.getRank()),
+                [&](OpBuilder &nested_builder, Location nested_loc,
+                    ValueRange args) {
+                  nested_builder.create<linalg::YieldOp>(loc, *args.begin());
+                })
+            ->getResults();
+
+    rewriter.replaceOpWithNewOp<linalg::MatmulOp>(
+        op, TypeRange{op.getType()},
+        ValueRange{adaptor.input(), adaptor.weight()}, linalgOp);
+    return success();
+  }
+};
+
 class ReshapeConverter : public OpConversionPattern<tosa::ReshapeOp> {
 public:
   using OpConversionPattern<tosa::ReshapeOp>::OpConversionPattern;
@@ -1041,6 +1119,6 @@ void mlir::tosa::populateTosaToLinalgOnTensorsConversionPatterns(
       IdentityNConverter<tosa::IdentityNOp>, ReduceConverter<tosa::ReduceMinOp>,
       ReduceConverter<tosa::ReduceMaxOp>, ReduceConverter<tosa::ReduceSumOp>,
       ReduceConverter<tosa::ReduceProdOp>, ConcatConverter, ReshapeConverter,
-      RescaleConverter, ReverseConverter, TransposeConverter>(
-      patterns->getContext());
+      RescaleConverter, ReverseConverter, TransposeConverter, MatMulConverter,
+      FullyConnectedConverter>(patterns->getContext());
 }

diff  --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index 33b82bc9e0fb3..2aaf6941bc7f2 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -639,3 +639,32 @@ func @reverse(%arg0: tensor<5x4xi32>) -> () {
 
   return
 }
+
+// -----
+
+
+// CHECK-LABEL: @matmul
+func @matmul(%arg0: tensor<5x3xf32>, %arg1: tensor<3x6xf32>, %arg2: tensor<6xf32>) -> (tensor<5x6xf32>) {
+  // CHECK: [[C0:%.+]] = constant 0
+  // CHECK: [[INIT:%.+]] = linalg.init_tensor [5, 6]
+  // CHECK: [[FILLED:%.+]] = linalg.fill([[INIT]], [[C0]]) : tensor<5x6xf32>, f32 -> tensor<5x6xf32>
+  // CHECK: linalg.matmul ins(%arg0, %arg1 : tensor<5x3xf32>, tensor<3x6xf32>) outs([[FILLED]] : tensor<5x6xf32>) -> tensor<5x6xf32>
+  %0 = "tosa.matmul"(%arg0, %arg1) : (tensor<5x3xf32>, tensor<3x6xf32>)  -> (tensor<5x6xf32>)
+  return %0 : tensor<5x6xf32>
+}
+
+// -----
+
+// CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1) -> (0, d1)>
+
+// CHECK-LABEL: @fully_connected
+func @fully_connected(%arg0: tensor<5x3xf32>, %arg1: tensor<3x6xf32>, %arg2: tensor<6xf32>) -> (tensor<5x6xf32>) {
+  // CHECK: [[RS:%.+]] = linalg.tensor_reshape %arg2 [#[[$MAP0]]]
+  // CHECK: [[INIT:%.+]] = linalg.init_tensor [5, 6]
+  // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP0]]], iterator_types = ["parallel", "parallel"]} ins([[RS]] : tensor<1x6xf32>) outs([[INIT]] : tensor<5x6xf32>) {
+  // CHECK: ^bb0([[IN:%.+]]: f32, [[MULTIPLIER:%.+]]: f32):
+  // CHECK: linalg.matmul ins(%arg0, %arg1 : tensor<5x3xf32>, tensor<3x6xf32>) outs([[GENERIC]] : tensor<5x6xf32>) -> tensor<5x6xf32>
+  %0 = "tosa.fully_connected"(%arg0, %arg1, %arg2) : (tensor<5x3xf32>, tensor<3x6xf32>, tensor<6xf32>)  -> (tensor<5x6xf32>)
+  return %0 : tensor<5x6xf32>
+}
