[Mlir-commits] [mlir] e20911b - [mlir][tosa] Add tosa.matmul and tosa.fully_connected lowering
Rob Suderman
llvmlistbot at llvm.org
Tue Mar 23 13:11:23 PDT 2021
Author: natashaknk
Date: 2021-03-23T13:09:53-07:00
New Revision: e20911b5c0360882ea166886af75c0038310f6e5
URL: https://github.com/llvm/llvm-project/commit/e20911b5c0360882ea166886af75c0038310f6e5
DIFF: https://github.com/llvm/llvm-project/commit/e20911b5c0360882ea166886af75c0038310f6e5.diff
LOG: [mlir][tosa] Add tosa.matmul and tosa.fully_connected lowering
Adds lowerings for matmul and fully_connected. The lowering currently supports only 2D tensors for the inputs and weights, and 1D tensors for the bias.
Reviewed By: rsuderman
Differential Revision: https://reviews.llvm.org/D99211
Added:
Modified:
mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index e0117e0f694fa..fe1336fb47319 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -612,6 +612,84 @@ class PointwiseConverter : public OpRewritePattern<SrcOp> {
}
};
+class MatMulConverter : public OpConversionPattern<tosa::MatMulOp> {
+public:
+ using OpConversionPattern<tosa::MatMulOp>::OpConversionPattern;
+ LogicalResult
+ matchAndRewrite(tosa::MatMulOp op, ArrayRef<Value> args,
+ ConversionPatternRewriter &rewriter) const final {
+ tosa::MatMulOp::Adaptor adaptor(args);
+
+ Location loc = op.getLoc();
+
+ auto outputTy = op.getType().cast<ShapedType>();
+ auto outputElementTy = outputTy.getElementType();
+ auto zero_attr = rewriter.getZeroAttr(outputElementTy);
+ Value zero = rewriter.create<ConstantOp>(loc, zero_attr);
+ auto initTensor = rewriter.create<linalg::InitTensorOp>(
+ loc, outputTy.getShape(), outputTy.getElementType());
+ Value zeroTensor =
+ rewriter.create<linalg::FillOp>(loc, initTensor, zero).getResult(0);
+ rewriter.replaceOpWithNewOp<linalg::MatmulOp>(
+ op, TypeRange{op.getType()}, ValueRange{adaptor.a(), adaptor.b()},
+ ValueRange{zeroTensor});
+ return success();
+ }
+};
+
+class FullyConnectedConverter
+ : public OpConversionPattern<tosa::FullyConnectedOp> {
+public:
+ using OpConversionPattern<tosa::FullyConnectedOp>::OpConversionPattern;
+ LogicalResult
+ matchAndRewrite(tosa::FullyConnectedOp op, ArrayRef<Value> args,
+ ConversionPatternRewriter &rewriter) const final {
+ tosa::FullyConnectedOp::Adaptor adaptor(args);
+
+ Location loc = op.getLoc();
+ auto outputTy = op.getType().cast<ShapedType>();
+ auto biasTy = op->getOperand(2).getType().cast<ShapedType>();
+
+ // Reshaping the bias from n to [1, n] for broadcasting
+ SmallVector<int64_t> biasShapeReshaped;
+ biasShapeReshaped.push_back(1);
+ biasShapeReshaped.push_back(biasTy.getShape()[0]);
+
+ RankedTensorType reshapedBias =
+ RankedTensorType::get(biasShapeReshaped, outputTy.getElementType());
+ auto reshapeResult =
+ rewriter.create<tosa::ReshapeOp>(loc, reshapedBias, args[2])
+ ->getResult(0);
+
+ // Creating maps for the output of MatMul and the bias
+ SmallVector<AffineMap, 4> indexingMaps;
+ indexingMaps.push_back(createAffineMapForType(reshapedBias, rewriter));
+ indexingMaps.push_back(rewriter.getMultiDimIdentityMap(outputTy.getRank()));
+
+ auto initTensor =
+ rewriter
+ .create<linalg::InitTensorOp>(loc, outputTy.getShape(),
+ outputTy.getElementType())
+ ->getResults();
+
+ auto linalgOp =
+ rewriter
+ .create<linalg::GenericOp>(
+ loc, outputTy, reshapeResult, initTensor, indexingMaps,
+ getNParallelLoopsAttrs(outputTy.getRank()),
+ [&](OpBuilder &nested_builder, Location nested_loc,
+ ValueRange args) {
+ nested_builder.create<linalg::YieldOp>(loc, *args.begin());
+ })
+ ->getResults();
+
+ rewriter.replaceOpWithNewOp<linalg::MatmulOp>(
+ op, TypeRange{op.getType()},
+ ValueRange{adaptor.input(), adaptor.weight()}, linalgOp);
+ return success();
+ }
+};
+
class ReshapeConverter : public OpConversionPattern<tosa::ReshapeOp> {
public:
using OpConversionPattern<tosa::ReshapeOp>::OpConversionPattern;
@@ -1041,6 +1119,6 @@ void mlir::tosa::populateTosaToLinalgOnTensorsConversionPatterns(
IdentityNConverter<tosa::IdentityNOp>, ReduceConverter<tosa::ReduceMinOp>,
ReduceConverter<tosa::ReduceMaxOp>, ReduceConverter<tosa::ReduceSumOp>,
ReduceConverter<tosa::ReduceProdOp>, ConcatConverter, ReshapeConverter,
- RescaleConverter, ReverseConverter, TransposeConverter>(
- patterns->getContext());
+ RescaleConverter, ReverseConverter, TransposeConverter, MatMulConverter,
+ FullyConnectedConverter>(patterns->getContext());
}
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index 33b82bc9e0fb3..2aaf6941bc7f2 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -639,3 +639,32 @@ func @reverse(%arg0: tensor<5x4xi32>) -> () {
return
}
+
+// -----
+
+
+// CHECK-LABEL: @matmul
+func @matmul(%arg0: tensor<5x3xf32>, %arg1: tensor<3x6xf32>, %arg2: tensor<6xf32>) -> (tensor<5x6xf32>) {
+ // CHECK: [[C0:%.+]] = constant 0
+ // CHECK: [[INIT:%.+]] = linalg.init_tensor [5, 6]
+ // CHECK: [[FILLED:%.+]] = linalg.fill([[INIT]], [[C0]]) : tensor<5x6xf32>, f32 -> tensor<5x6xf32>
+ // CHECK: linalg.matmul ins(%arg0, %arg1 : tensor<5x3xf32>, tensor<3x6xf32>) outs([[FILLED]] : tensor<5x6xf32>) -> tensor<5x6xf32>
+ %0 = "tosa.matmul"(%arg0, %arg1) : (tensor<5x3xf32>, tensor<3x6xf32>) -> (tensor<5x6xf32>)
+ return %0 : tensor<5x6xf32>
+}
+
+// -----
+
+// CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1) -> (0, d1)>
+
+// CHECK-LABEL: @fully_connected
+func @fully_connected(%arg0: tensor<5x3xf32>, %arg1: tensor<3x6xf32>, %arg2: tensor<6xf32>) -> (tensor<5x6xf32>) {
+ // CHECK: [[RS:%.+]] = linalg.tensor_reshape %arg2 [#[[$MAP0]]]
+ // CHECK: [[INIT:%.+]] = linalg.init_tensor [5, 6]
+ // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP0]]], iterator_types = ["parallel", "parallel"]} ins([[RS]] : tensor<1x6xf32>) outs([[INIT]] : tensor<5x6xf32>) {
+ // CHECK: ^bb0([[IN:%.+]]: f32, [[MULTIPLIER:%.+]]: f32):
+ // CHECK: linalg.matmul ins(%arg0, %arg1 : tensor<5x3xf32>, tensor<3x6xf32>) outs([[GENERIC]] : tensor<5x6xf32>) -> tensor<5x6xf32>
+ %0 = "tosa.fully_connected"(%arg0, %arg1, %arg2) : (tensor<5x3xf32>, tensor<3x6xf32>, tensor<6xf32>) -> (tensor<5x6xf32>)
+ return %0 : tensor<5x6xf32>
+}
More information about the Mlir-commits
mailing list