[Mlir-commits] [mlir] [mlir][tosa] Improve lowering to tosa.fully_connected (PR #73049)
Spenser Bauman
llvmlistbot at llvm.org
Fri Dec 1 06:17:51 PST 2023
https://github.com/sabauma updated https://github.com/llvm/llvm-project/pull/73049
>From ec8cd7a189c183b9159aa601880239acf0b0be04 Mon Sep 17 00:00:00 2001
From: Spenser Bauman <sbauman at mathworks.com>
Date: Tue, 21 Nov 2023 17:02:20 -0500
Subject: [PATCH] [mlir][tosa] Improve lowering to tosa.fully_connected
The current lowering of tosa.fully_connected produces
a linalg.matmul followed by a linalg.generic to add the bias.
The IR looks like the following:
%init = tensor.empty()
%zero = linalg.fill ins(0 : f32) outs(%init)
%prod = linalg.matmul ins(%A, %B) outs(%zero)
%initB = tensor.empty()
%result = linalg.generic ins(%prod, %bias) outs(%initB)
This has two downsides:
1. The tensor.empty operations typically result in additional
allocations after bufferization
2. There is a redundant traversal of the data to add the bias to the
matrix product.
This extra work can be avoided by leveraging the out-param of
linalg.matmul. The new IR sequence is:
%init = tensor.empty()
%broadcast = linalg.broadcast ins(%bias) outs(%init)
%prod = linalg.matmul ins(%A, %B) outs(%broadcast)
In my experiments, this eliminates one loop and one allocation (post
bufferization) from the generated code.
---
.../TosaToLinalg/TosaToLinalgNamed.cpp | 91 +++++++++++--------
.../TosaToLinalg/tosa-to-linalg-named.mlir | 84 +++++++++--------
.../Tosa/CPU/test-fully-connected.mlir | 36 ++++++++
3 files changed, 128 insertions(+), 83 deletions(-)
create mode 100644 mlir/test/Integration/Dialect/Tosa/CPU/test-fully-connected.mlir
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
index 9e374be534985e5..b30651976eeb939 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
@@ -85,6 +85,49 @@ linalgIntBroadcastExtSIAdd(PatternRewriter &rewriter, Location loc, Value bias,
.getResult(0);
}
+// Broadcast the source value to all the outer dimensions of the result value.
+// If required, the element type is expanded using an arith.extsi operation.
+static mlir::Value linalgBroadcastAndMaybeExtSI(PatternRewriter &rewriter,
+ Location loc, Value source,
+ Value result) {
+ ShapedType resultTy = cast<ShapedType>(result.getType());
+ ShapedType sourceTy = cast<ShapedType>(source.getType());
+ int64_t resultRank = resultTy.getRank();
+ int64_t sourceRank = sourceTy.getRank();
+
+ // The source tensor is broadcast to all the outer dimensions of the
+ // result tensor.
+ SmallVector<AffineExpr> sourceDims;
+ for (auto dim : llvm::seq<int64_t>(0, sourceRank)) {
+ auto expr = rewriter.getAffineDimExpr(dim + resultRank - sourceRank);
+ sourceDims.push_back(expr);
+ }
+
+ // Creating maps for the input and output of the broadcast-like generic op.
+ SmallVector<AffineMap, 2> indexingMaps = {
+ // Input map: the source indexes only the trailing result dimensions.
+ AffineMap::get(/*dimCount=*/resultRank,
+ /*symbolCount=*/0, sourceDims, rewriter.getContext()),
+
+ // Output indexing map.
+ rewriter.getMultiDimIdentityMap(resultRank)};
+
+ // Build the broadcast-like operation as a linalg.generic.
+ return rewriter
+ .create<linalg::GenericOp>(
+ loc, resultTy, ValueRange({source}), result, indexingMaps,
+ getNParallelLoopsAttrs(resultTy.getRank()),
+ [](OpBuilder &builder, Location loc, ValueRange args) {
+ Value biasVal = args[0];
+ Type resType = args[1].getType();
+ if (resType != biasVal.getType()) {
+ biasVal = builder.create<arith::ExtSIOp>(loc, resType, biasVal);
+ }
+ builder.create<linalg::YieldOp>(loc, biasVal);
+ })
+ .getResult(0);
+}
+
static mlir::Value reifyConstantDim(int64_t attr,
ImplicitLocOpBuilder &builder) {
return builder.createOrFold<arith::IndexCastOp>(
@@ -618,28 +661,6 @@ class FullyConnectedConverter
SmallVector<Value> filteredDims = condenseValues(dynDims);
- // Creating maps for the output of MatMul and the bias
- SmallVector<AffineMap, 4> indexingMaps;
-
- // Broadcast the bias.
- indexingMaps.push_back(AffineMap::get(/*dimCount=*/2, /*symbolCount=*/0,
- {rewriter.getAffineDimExpr(1)},
- rewriter.getContext()));
-
- indexingMaps.push_back(rewriter.getMultiDimIdentityMap(outputTy.getRank()));
- indexingMaps.push_back(rewriter.getMultiDimIdentityMap(outputTy.getRank()));
-
- auto emptyTensor = rewriter.create<tensor::EmptyOp>(
- loc, outputTy.getShape(), outputTy.getElementType(), filteredDims);
-
- // When quantized, the input elemeny type is not the same as the output
- auto resultZeroAttr = rewriter.getZeroAttr(outputETy);
- Value zero = rewriter.create<arith::ConstantOp>(loc, resultZeroAttr);
- Value zeroTensor = rewriter
- .create<linalg::FillOp>(loc, ValueRange{zero},
- ValueRange{emptyTensor})
- .result();
-
SmallVector<int64_t> permutation{1, 0};
auto permutationAttr = rewriter.getI64TensorAttr(permutation);
Value permutationValue =
@@ -655,26 +676,17 @@ class FullyConnectedConverter
Value biasEmptyTensor = rewriter.create<tensor::EmptyOp>(
loc, outputTy.getShape(), outputETy, filteredDims);
+ Value broadcastBias =
+ linalgBroadcastAndMaybeExtSI(rewriter, loc, bias, biasEmptyTensor);
+
if (!op.getQuantizationInfo()) {
Value matmul = rewriter
.create<linalg::MatmulOp>(
loc, TypeRange{op.getType()},
- ValueRange{input, transposedWeight}, zeroTensor)
+ ValueRange{input, transposedWeight}, broadcastBias)
->getResult(0);
- Value result =
- rewriter
- .create<linalg::GenericOp>(
- loc, outputTy, ValueRange({bias, matmul}), biasEmptyTensor,
- indexingMaps, getNParallelLoopsAttrs(outputTy.getRank()),
- [&](OpBuilder &nestedBuilder, Location nestedLoc,
- ValueRange args) {
- Value added = nestedBuilder.create<arith::AddFOp>(
- loc, args[0], args[1]);
- nestedBuilder.create<linalg::YieldOp>(nestedLoc, added);
- })
- .getResult(0);
- rewriter.replaceOp(op, result);
+ rewriter.replaceOp(op, matmul);
return success();
}
@@ -688,11 +700,10 @@ class FullyConnectedConverter
.create<linalg::QuantizedMatmulOp>(
loc, TypeRange{op.getType()},
ValueRange{input, transposedWeight, inputZp, outputZp},
- zeroTensor)
+ broadcastBias)
->getResult(0);
- Value result = linalgIntBroadcastExtSIAdd(rewriter, loc, bias, matmul,
- biasEmptyTensor, indexingMaps);
- rewriter.replaceOp(op, result);
+
+ rewriter.replaceOp(op, matmul);
return success();
}
};
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
index 4edc75331932803..bbdd1bad799865d 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
@@ -82,22 +82,21 @@ func.func @matmul_dyn_output(%arg0: tensor<1x1x8xf32>, %arg1: tensor<1x8x1xf32>)
// -----
-// CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1) -> (d1)>
-// CHECK: #[[$MAP2:.*]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK: #[[$MAP0:.+]] = affine_map<(d0, d1) -> (d1)>
+// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1) -> (d0, d1)>
// CHECK-LABEL: @fully_connected
func.func @fully_connected(%arg0: tensor<5x3xf32>, %arg1: tensor<6x3xf32>, %arg2: tensor<6xf32>) -> (tensor<5x6xf32>) {
- // CHECK: [[INITT:%.+]] = tensor.empty()
- // CHECK: [[ZERO:%.+]] = arith.constant 0
- // CHECK: [[FILL:%.+]] = linalg.fill ins([[ZERO]]{{.*}}outs([[INITT]]
- // CHECK: [[PERM:%.+]] = arith.constant dense<[1, 0]>
- // CHECK: [[TRANSPOSE:%.+]] = tosa.transpose %arg1, [[PERM]]
- // CHECK: [[INITB:%.+]] = tensor.empty()
- // CHECK: [[MATMUL:%.+]] = linalg.matmul ins(%arg0, [[TRANSPOSE]] : tensor<5x3xf32>, tensor<3x6xf32>) outs([[FILL]] : tensor<5x6xf32>) -> tensor<5x6xf32>
- // CHECK: [[ADDED:%.+]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP2]], #[[$MAP2]]], iterator_types = ["parallel", "parallel"]} ins(%arg2, [[MATMUL]] : tensor<6xf32>, tensor<5x6xf32>) outs([[INITB]] : tensor<5x6xf32>) {
- // CHECK: ^bb0(%[[ARG3:[0-9a-zA-Z_]+]]: f32, %[[ARG4:[0-9a-zA-Z_]+]]: f32, %[[ARG5:[0-9a-zA-Z_]+]]: f32):
- // CHECK: [[ADD:%.+]] = arith.addf %[[ARG3]], %[[ARG4]] : f32
- // CHECK: linalg.yield [[ADD]] : f32
+ // CHECK: %[[PERM:.+]] = arith.constant dense<[1, 0]> : tensor<2xi64>
+ // CHECK: %[[TRANSPOSED:.+]] = tosa.transpose %arg1, %[[PERM]] : (tensor<6x3xf32>, tensor<2xi64>) -> tensor<3x6xf32>
+ // CHECK: %[[INIT:.+]] = tensor.empty() : tensor<5x6xf32>
+
+ // CHECK: %[[BROADCAST:.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel"]} ins(%arg2 : tensor<6xf32>) outs(%[[INIT]] : tensor<5x6xf32>) {
+ // CHECK: ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
+ // CHECK: linalg.yield %[[IN]] : f32
+ // CHECK: } -> tensor<5x6xf32>
+
+ // CHECK: linalg.matmul ins(%arg0, %[[TRANSPOSED]] : tensor<5x3xf32>, tensor<3x6xf32>) outs(%[[BROADCAST]] : tensor<5x6xf32>) -> tensor<5x6xf32>
%0 = tosa.fully_connected %arg0, %arg1, %arg2 : (tensor<5x3xf32>, tensor<6x3xf32>, tensor<6xf32>) -> tensor<5x6xf32>
return %0 : tensor<5x6xf32>
@@ -105,48 +104,47 @@ func.func @fully_connected(%arg0: tensor<5x3xf32>, %arg1: tensor<6x3xf32>, %arg2
// -----
-// CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1) -> (d1)>
-// CHECK: #[[$MAP2:.*]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK: #[[$MAP0:.+]] = affine_map<(d0, d1) -> (d1)>
+// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1) -> (d0, d1)>
// CHECK-LABEL: @quantized_fully_connected
func.func @quantized_fully_connected(%arg0: tensor<5x3xi8>, %arg1: tensor<6x3xi8>, %arg2: tensor<6xi32>) -> (tensor<5x6xi32>) {
- // CHECK: [[INITT:%.+]] = tensor.empty()
- // CHECK: [[ZERO:%.+]] = arith.constant 0
- // CHECK: [[FILL:%.+]] = linalg.fill ins([[ZERO]]{{.*}}outs([[INITT]]
- // CHECK: [[PERM:%.+]] = arith.constant dense<[1, 0]>
- // CHECK: [[TRANSPOSE:%.+]] = tosa.transpose %arg1, [[PERM]]
- // CHECK: [[INITB:%.+]] = tensor.empty()
- // CHECK: [[ONE:%.+]] = arith.constant 1
- // CHECK: [[TWO:%.+]] = arith.constant 2
- // CHECK: [[MATMUL:%.+]] = linalg.quantized_matmul ins(%arg0, [[TRANSPOSE]], [[ONE]], [[TWO]] : tensor<5x3xi8>, tensor<3x6xi8>, i32, i32) outs([[FILL]] : tensor<5x6xi32>) -> tensor<5x6xi32>
- // CHECK: [[ADDED:%.+]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP2]], #[[$MAP2]]], iterator_types = ["parallel", "parallel"]} ins(%arg2, [[MATMUL]] : tensor<6xi32>, tensor<5x6xi32>) outs([[INITB]]
- // CHECK: ^bb0([[IN1:%.+]]: i32, [[IN2:%.+]]: i32, [[UNUSED:%.+]]: i32):
- // CHECK: [[ADD:%.+]] = arith.addi
- // CHECK: linalg.yield [[ADD]] : i32
+ // CHECK: %[[PERM:.+]] = arith.constant dense<[1, 0]> : tensor<2xi64>
+ // CHECK: %[[TRANSPOSE:.+]] = tosa.transpose %arg1, %[[PERM]] : (tensor<6x3xi8>, tensor<2xi64>) -> tensor<3x6xi8>
+ // CHECK: %[[INIT:.+]] = tensor.empty() : tensor<5x6xi32>
+
+ // CHECK: %[[BROADCAST:.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel"]} ins(%arg2 : tensor<6xi32>) outs(%[[INIT]] : tensor<5x6xi32>) {
+ // CHECK: ^bb0(%[[IN:.+]]: i32, %[[OUT:.+]]: i32):
+ // CHECK: linalg.yield %[[IN]] : i32
+ // CHECK: } -> tensor<5x6xi32>
+
+ // CHECK: %[[C1:.+]] = arith.constant 1 : i32
+ // CHECK: %[[C2:.+]] = arith.constant 2 : i32
+ // CHECK: linalg.quantized_matmul ins(%arg0, %[[TRANSPOSE]], %[[C1]], %[[C2]] : tensor<5x3xi8>, tensor<3x6xi8>, i32, i32) outs(%[[BROADCAST]] : tensor<5x6xi32>) -> tensor<5x6xi32>
+
%0 = tosa.fully_connected %arg0, %arg1, %arg2 {quantization_info = #tosa.conv_quant<input_zp = 1, weight_zp = 2>} : (tensor<5x3xi8>, tensor<6x3xi8>, tensor<6xi32>) -> tensor<5x6xi32>
return %0 : tensor<5x6xi32>
}
// -----
-// CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1) -> (d1)>
-// CHECK: #[[$MAP2:.*]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK: #[[$MAP0:.+]] = affine_map<(d0, d1) -> (d1)>
+// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1) -> (d0, d1)>
// CHECK-LABEL: @fully_connected_dyn
func.func @fully_connected_dyn(%arg0: tensor<?x3xf32>, %arg1: tensor<6x3xf32>, %arg2: tensor<6xf32>) -> (tensor<?x6xf32>) {
- // CHECK: %[[C0:.+]] = arith.constant 0
- // CHECK: %[[DIM:.+]] = tensor.dim %arg0, %[[C0]]
- // CHECK: %[[INITT:.+]] = tensor.empty(%[[DIM]])
- // CHECK: %[[ZERO:.+]] = arith.constant 0
- // CHECK: %[[FILL:.+]] = linalg.fill ins(%[[ZERO]]{{.*}}outs(%[[INITT]]
- // CHECK: %[[PERM:.+]] = arith.constant dense<[1, 0]>
- // CHECK: %[[TRANSPOSE:.+]] = tosa.transpose %arg1, %[[PERM]]
- // CHECK: %[[INITB:.+]] = tensor.empty(%[[DIM]])
- // CHECK: %[[MATMUL:.+]] = linalg.matmul ins(%arg0, %[[TRANSPOSE]] : tensor<?x3xf32>, tensor<3x6xf32>) outs(%[[FILL]] : tensor<?x6xf32>) -> tensor<?x6xf32>
- // CHECK: %[[ADDED:.+]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP2]], #[[$MAP2]]], iterator_types = ["parallel", "parallel"]} ins(%arg2, %[[MATMUL]] : tensor<6xf32>, tensor<?x6xf32>) outs(%[[INITB]] : tensor<?x6xf32>) {
- // CHECK: ^bb0(%[[ARG3:[0-9a-zA-Z_]+]]: f32, %[[ARG4:[0-9a-zA-Z_]+]]: f32, %[[ARG5:[0-9a-zA-Z_]+]]: f32):
- // CHECK: %[[ADD:.+]] = arith.addf %[[ARG3]], %[[ARG4]] : f32
- // CHECK: linalg.yield %[[ADD]] : f32
+ // CHECK: %[[C0:.+]] = arith.constant 0 : index
+ // CHECK: %[[DIM0:.+]] = tensor.dim %arg0, %[[C0]] : tensor<?x3xf32>
+ // CHECK: %[[PERM:.+]] = arith.constant dense<[1, 0]> : tensor<2xi64>
+ // CHECK: %[[TRANSPOSED:.+]] = tosa.transpose %arg1, %[[PERM]] : (tensor<6x3xf32>, tensor<2xi64>) -> tensor<3x6xf32>
+ // CHECK: %[[INIT:.+]] = tensor.empty(%[[DIM0]]) : tensor<?x6xf32>
+
+ // CHECK: %[[BROADCAST:.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel"]} ins(%arg2 : tensor<6xf32>) outs(%[[INIT]] : tensor<?x6xf32>) {
+ // CHECK: ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
+ // CHECK: linalg.yield %[[IN]] : f32
+ // CHECK: } -> tensor<?x6xf32>
+
+ // CHECK: linalg.matmul ins(%arg0, %[[TRANSPOSED]] : tensor<?x3xf32>, tensor<3x6xf32>) outs(%[[BROADCAST]] : tensor<?x6xf32>) -> tensor<?x6xf32>
%0 = tosa.fully_connected %arg0, %arg1, %arg2 : (tensor<?x3xf32>, tensor<6x3xf32>, tensor<6xf32>) -> tensor<?x6xf32>
return %0 : tensor<?x6xf32>
diff --git a/mlir/test/Integration/Dialect/Tosa/CPU/test-fully-connected.mlir b/mlir/test/Integration/Dialect/Tosa/CPU/test-fully-connected.mlir
new file mode 100644
index 000000000000000..bf178c826574e4d
--- /dev/null
+++ b/mlir/test/Integration/Dialect/Tosa/CPU/test-fully-connected.mlir
@@ -0,0 +1,36 @@
+// RUN: mlir-opt %s -pass-pipeline="builtin.module(func.func(tosa-to-linalg-named,tosa-to-linalg,tosa-to-arith))" | \
+// RUN: mlir-opt -one-shot-bufferize -func-bufferize -test-lower-to-llvm | \
+// RUN: mlir-cpu-runner -O3 -e main -entry-point-result=void \
+// RUN: -shared-libs=%mlir_runner_utils \
+// RUN: | FileCheck %s
+
+func.func private @printMemrefF32(tensor<*xf32>)
+
+func.func @main() {
+ %A = arith.constant dense<[
+ [8.0, 1.0, 6.0],
+ [3.0, 5.0, 7.0],
+ [4.0, 9.0, 2.0]
+ ]> : tensor<3x3xf32>
+
+ %B = arith.constant dense<[
+ [1.0, 1.0, 1.0],
+ [1.0, 1.0, 1.0],
+ [1.0, 1.0, 1.0]
+ ]> : tensor<3x3xf32>
+
+ %C = arith.constant dense<[0.0, 1.0, 2.0]> : tensor<3xf32>
+
+ %result = tosa.fully_connected %A, %B, %C : (tensor<3x3xf32>, tensor<3x3xf32>, tensor<3xf32>) -> tensor<3x3xf32>
+
+ %result_unranked = tensor.cast %result : tensor<3x3xf32> to tensor<*xf32>
+ call @printMemrefF32(%result_unranked) : (tensor<*xf32>) -> ()
+ return
+}
+
+// CHECK: Unranked Memref base@ = {{.*}} rank = 2 offset = 0 sizes = [3, 3] strides = [3, 1] data =
+// CHECK-NEXT: [
+// CHECK-SAME: [15, 16, 17]
+// CHECK-NEXT: [15, 16, 17]
+// CHECK-NEXT: [15, 16, 17]
+// CHECK-SAME: ]
More information about the Mlir-commits
mailing list