[Mlir-commits] [mlir] cc1ae54 - [tosa][mlir] Fix FullyConnected to correctly order dimensions
Rob Suderman
llvmlistbot at llvm.org
Tue Apr 27 17:33:22 PDT 2021
Author: Rob Suderman
Date: 2021-04-27T17:26:04-07:00
New Revision: cc1ae54ebcc4072ce19b610511f500f15c7acd8a
URL: https://github.com/llvm/llvm-project/commit/cc1ae54ebcc4072ce19b610511f500f15c7acd8a
DIFF: https://github.com/llvm/llvm-project/commit/cc1ae54ebcc4072ce19b610511f500f15c7acd8a.diff
LOG: [tosa][mlir] Fix FullyConnected to correctly order dimensions
MatMul and FullyConnected have transposed dimensions for the weights.
Also, removed unneeded tensor reshape for bias.
Differential Revision: https://reviews.llvm.org/D101220
Added:
Modified:
mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index 51de267170ad..b31656467c48 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -745,26 +745,28 @@ class FullyConnectedConverter
LogicalResult
matchAndRewrite(tosa::FullyConnectedOp op, ArrayRef<Value> args,
ConversionPatternRewriter &rewriter) const final {
- tosa::FullyConnectedOp::Adaptor adaptor(args);
-
Location loc = op.getLoc();
auto outputTy = op.getType().cast<ShapedType>();
- auto biasTy = op->getOperand(2).getType().cast<ShapedType>();
+ auto input = op.input();
+ auto weight = op.weight();
+ auto bias = op.bias();
- // Reshaping the bias from n to [1, n] for broadcasting
- SmallVector<int64_t> biasShapeReshaped;
- biasShapeReshaped.push_back(1);
- biasShapeReshaped.push_back(biasTy.getShape()[0]);
+ auto weightTy = weight.getType().cast<ShapedType>();
+ auto biasTy = bias.getType().cast<ShapedType>();
- RankedTensorType reshapedBias =
- RankedTensorType::get(biasShapeReshaped, outputTy.getElementType());
- auto reshapeResult =
- rewriter.create<tosa::ReshapeOp>(loc, reshapedBias, args[2])
- ->getResult(0);
+ auto weightShape = weightTy.getShape();
+
+ if (op.quantization_info())
+ return failure();
// Creating maps for the output of MatMul and the bias
SmallVector<AffineMap, 4> indexingMaps;
- indexingMaps.push_back(createAffineMapForType(reshapedBias, rewriter));
+
+ // Broadcast the bias.
+ indexingMaps.push_back(AffineMap::get(/*dimCount=*/2, /*symbolCount=*/0,
+ {rewriter.getAffineDimExpr(1)},
+ rewriter.getContext()));
+
indexingMaps.push_back(rewriter.getMultiDimIdentityMap(outputTy.getRank()));
auto initTensor =
@@ -776,7 +778,7 @@ class FullyConnectedConverter
auto linalgOp =
rewriter
.create<linalg::GenericOp>(
- loc, outputTy, reshapeResult, initTensor, indexingMaps,
+ loc, outputTy, bias, initTensor, indexingMaps,
getNParallelLoopsAttrs(outputTy.getRank()),
[&](OpBuilder &nested_builder, Location nested_loc,
ValueRange args) {
@@ -784,9 +786,21 @@ class FullyConnectedConverter
})
->getResults();
+ SmallVector<int64_t> permutation{1, 0};
+ auto permutationAttr = DenseIntElementsAttr::get(
+ RankedTensorType::get({2}, rewriter.getI64Type()), permutation);
+ Value permutationValue = rewriter.create<ConstantOp>(loc, permutationAttr);
+
+ SmallVector<int64_t> newWeightShape{weightShape[1], weightShape[0]};
+ Type newWeightTy =
+ RankedTensorType::get(newWeightShape, biasTy.getElementType());
+
+ Value transposedWeight = rewriter.create<tosa::TransposeOp>(
+ loc, newWeightTy, weight, permutationValue);
+
rewriter.replaceOpWithNewOp<linalg::MatmulOp>(
- op, TypeRange{op.getType()},
- ValueRange{adaptor.input(), adaptor.weight()}, linalgOp);
+ op, TypeRange{op.getType()}, ValueRange{input, transposedWeight},
+ linalgOp);
return success();
}
};
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index 3f9940e4a203..57d8c86fa25b 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -756,17 +756,22 @@ func @matmul(%arg0: tensor<5x3xf32>, %arg1: tensor<3x6xf32>, %arg2: tensor<6xf32
// -----
-// CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)>
-// CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1) -> (0, d1)>
+// CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d1)>
+// CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK: #[[$MAP2:.*]] = affine_map<(d0, d1) -> (d1, d0)>
// CHECK-LABEL: @fully_connected
-func @fully_connected(%arg0: tensor<5x3xf32>, %arg1: tensor<3x6xf32>, %arg2: tensor<6xf32>) -> (tensor<5x6xf32>) {
- // CHECK: [[RS:%.+]] = linalg.tensor_reshape %arg2 [#[[$MAP0]]]
- // CHECK: [[INIT:%.+]] = linalg.init_tensor [5, 6]
- // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP0]]], iterator_types = ["parallel", "parallel"]} ins([[RS]] : tensor<1x6xf32>) outs([[INIT]] : tensor<5x6xf32>) {
- // CHECK: ^bb0([[IN:%.+]]: f32, [[MULTIPLIER:%.+]]: f32):
- // CHECK: linalg.matmul ins(%arg0, %arg1 : tensor<5x3xf32>, tensor<3x6xf32>) outs([[GENERIC]] : tensor<5x6xf32>) -> tensor<5x6xf32>
- %0 = "tosa.fully_connected"(%arg0, %arg1, %arg2) : (tensor<5x3xf32>, tensor<3x6xf32>, tensor<6xf32>) -> (tensor<5x6xf32>)
+func @fully_connected(%arg0: tensor<5x3xf32>, %arg1: tensor<6x3xf32>, %arg2: tensor<6xf32>) -> (tensor<5x6xf32>) {
+ // CHECK: [[INITB:%.+]] = linalg.init_tensor [5, 6]
+ // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel"]} ins(%arg2 : tensor<6xf32>) outs([[INITB]] : tensor<5x6xf32>) {
+ // CHECK: ^bb0([[IN:%.+]]: f32, [[UNUSED:%.+]]: f32):
+ // CHECK: linalg.yield [[IN]] : f32
+ // CHECK: [[INITT:%.+]] = linalg.init_tensor [3, 6]
+ // CHECK: [[TRANSPOSE:%.+]] = linalg.generic {indexing_maps = [#[[$MAP2]], #[[$MAP1]]], iterator_types = ["parallel", "parallel"]} ins(%arg1 : tensor<6x3xf32>) outs([[INITT]]
+ // CHECK: ^bb0([[IN:%.+]]: f32, [[UNUSED:%.+]]: f32):
+ // CHECK: linalg.yield [[IN]] : f32
+ // CHECK: linalg.matmul ins(%arg0, [[TRANSPOSE]] : tensor<5x3xf32>, tensor<3x6xf32>) outs([[GENERIC]] : tensor<5x6xf32>) -> tensor<5x6xf32>
+ %0 = "tosa.fully_connected"(%arg0, %arg1, %arg2) : (tensor<5x3xf32>, tensor<6x3xf32>, tensor<6xf32>) -> (tensor<5x6xf32>)
return %0 : tensor<5x6xf32>
}
More information about the Mlir-commits
mailing list