[Mlir-commits] [mlir] [mlir][tosa] Improve lowering to tosa.fully_connected (PR #73049)

Spenser Bauman llvmlistbot at llvm.org
Tue Nov 21 14:46:30 PST 2023

https://github.com/sabauma created https://github.com/llvm/llvm-project/pull/73049

The current lowering of tosa.fully_connected produces a linalg.matmul followed by a linalg.generic to add the bias. The IR looks like the following:

    %init = tensor.empty()
    %zero = linalg.fill ins(0 : f32) outs(%init)
    %prod = linalg.matmul ins(%A, %B) outs(%zero)

    // Add the bias
    %initB = tensor.empty()
    %result = linalg.generic ins(%prod, %bias) outs(%initB) {
       // add bias and product

This has two downsides:

1. The tensor.empty operations typically result in additional allocations after bufferization.
2. There is a redundant traversal of the data to add the bias to the matrix product.

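This extra work can be avoided by leveraging the out-param of linalg.matmul. linalg.matmul accumulates into its init (outs) tensor, so initializing that tensor with the broadcast bias yields A*B + bias directly. The new IR sequence is: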

    %init = tensor.empty()
    %broadcast = linalg.broadcast ins(%bias) outs(%init)
    %prod = linalg.matmul ins(%A, %B) outs(%broadcast)

In my experiments, this eliminates one loop and one allocation (after bufferization) from the generated code.

>From 43bf464273de8efd08bf81525a1d3eb095839e5d Mon Sep 17 00:00:00 2001
From: Spenser Bauman <sbauman at mathworks.com>
Date: Tue, 21 Nov 2023 17:02:20 -0500
Subject: [PATCH] [mlir][tosa] Improve lowering to tosa.fully_connected

The current lowering of tosa.fully_connected produces
a linalg.matmul followed by a linalg.generic to add the bias.
The IR looks like the following:

    %init = tensor.empty()
    %zero = linalg.fill ins(0 : f32) outs(%init)
    %prod = linalg.matmul ins(%A, %B) outs(%zero)

    %initB = tensor.empty()
    %result = linalg.generic ins(%prod, %bias) outs(%initB)

This has two downsides:

1. The tensor.empty operations typically result in additional
   allocations after bufferization.
2. There is a redundant traversal of the data to add the bias to the
   matrix product.

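This extra work can be avoided by leveraging the out-param of
linalg.matmul. linalg.matmul accumulates into its init (outs) tensor,
so initializing that tensor with the broadcast bias yields A*B + bias
directly. The new IR sequence is: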

    %init = tensor.empty()
    %broadcast = linalg.broadcast ins(%bias) outs(%init)
    %prod = linalg.matmul ins(%A, %B) outs(%broadcast)

In my experiments, this eliminates one loop and one allocation (after
bufferization) from the generated code.
 .../TosaToLinalg/TosaToLinalgNamed.cpp        | 38 +++--------
 .../TosaToLinalg/tosa-to-linalg-named.mlir    | 66 ++++++-------------
 2 files changed, 29 insertions(+), 75 deletions(-)

diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
index 99a65f63038a43f..b9a7b778ce4017d 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
@@ -632,17 +632,6 @@ class FullyConnectedConverter
-    auto emptyTensor = rewriter.create<tensor::EmptyOp>(
-        loc, outputTy.getShape(), outputTy.getElementType(), filteredDims);
-    // When quantized, the input elemeny type is not the same as the output
-    auto resultZeroAttr = rewriter.getZeroAttr(outputETy);
-    Value zero = rewriter.create<arith::ConstantOp>(loc, resultZeroAttr);
-    Value zeroTensor = rewriter
-                           .create<linalg::FillOp>(loc, ValueRange{zero},
-                                                   ValueRange{emptyTensor})
-                           .result();
     SmallVector<int64_t> permutation{1, 0};
     auto permutationAttr = rewriter.getI64TensorAttr(permutation);
     Value permutationValue =
@@ -658,26 +647,18 @@ class FullyConnectedConverter
     Value biasEmptyTensor = rewriter.create<tensor::EmptyOp>(
         loc, outputTy.getShape(), outputETy, filteredDims);
+    auto broadcastDims = DenseI64ArrayAttr::get(getContext(), {0});
+    Value biasInitTensor = rewriter.create<linalg::BroadcastOp>(
+        loc, bias, biasEmptyTensor, broadcastDims)->getResult(0);
     if (!op.getQuantizationInfo()) {
       Value matmul = rewriter
                              loc, TypeRange{op.getType()},
-                             ValueRange{input, transposedWeight}, zeroTensor)
+                             ValueRange{input, transposedWeight}, biasInitTensor)
-      Value result =
-          rewriter
-              .create<linalg::GenericOp>(
-                  loc, outputTy, ValueRange({bias, matmul}), biasEmptyTensor,
-                  indexingMaps, getNParallelLoopsAttrs(outputTy.getRank()),
-                  [&](OpBuilder &nestedBuilder, Location nestedLoc,
-                      ValueRange args) {
-                    Value added = nestedBuilder.create<arith::AddFOp>(
-                        loc, args[0], args[1]);
-                    nestedBuilder.create<linalg::YieldOp>(nestedLoc, added);
-                  })
-              .getResult(0);
-      rewriter.replaceOp(op, result);
+      rewriter.replaceOp(op, matmul);
       return success();
@@ -691,11 +672,10 @@ class FullyConnectedConverter
                 loc, TypeRange{op.getType()},
                 ValueRange{input, transposedWeight, inputZp, outputZp},
-                zeroTensor)
+                biasInitTensor)
-    Value result = linalgIntBroadcastExtSIAdd(rewriter, loc, bias, matmul,
-                                              biasEmptyTensor, indexingMaps);
-    rewriter.replaceOp(op, result);
+    rewriter.replaceOp(op, matmul);
     return success();
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
index 1cf7c8dee606899..3b6d574b73b1ab6 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
@@ -68,22 +68,13 @@ func.func @matmul_dyn_independent_dim(%arg0: tensor<1x5x?xf32>, %arg1: tensor<1x
 // -----
-// CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1) -> (d1)>
-// CHECK: #[[$MAP2:.*]] = affine_map<(d0, d1) -> (d0, d1)>
 // CHECK-LABEL: @fully_connected
 func.func @fully_connected(%arg0: tensor<5x3xf32>, %arg1: tensor<6x3xf32>, %arg2: tensor<6xf32>) -> (tensor<5x6xf32>) {
-  // CHECK: [[INITT:%.+]] = tensor.empty()
-  // CHECK: [[ZERO:%.+]] = arith.constant 0
-  // CHECK: [[FILL:%.+]] = linalg.fill ins([[ZERO]]{{.*}}outs([[INITT]]
-  // CHECK: [[PERM:%.+]] = arith.constant dense<[1, 0]>
-  // CHECK: [[TRANSPOSE:%.+]] = tosa.transpose %arg1, [[PERM]]
-  // CHECK: [[INITB:%.+]] = tensor.empty()
-  // CHECK: [[MATMUL:%.+]] = linalg.matmul ins(%arg0, [[TRANSPOSE]] : tensor<5x3xf32>, tensor<3x6xf32>) outs([[FILL]] : tensor<5x6xf32>) -> tensor<5x6xf32>
-  // CHECK: [[ADDED:%.+]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP2]], #[[$MAP2]]], iterator_types = ["parallel", "parallel"]} ins(%arg2, [[MATMUL]] : tensor<6xf32>, tensor<5x6xf32>) outs([[INITB]] : tensor<5x6xf32>) {
-  // CHECK: ^bb0(%[[ARG3:[0-9a-zA-Z_]+]]: f32, %[[ARG4:[0-9a-zA-Z_]+]]: f32, %[[ARG5:[0-9a-zA-Z_]+]]: f32):
-  // CHECK:   [[ADD:%.+]] = arith.addf %[[ARG3]], %[[ARG4]] : f32
-  // CHECK:   linalg.yield [[ADD]] : f32
+  // CHECK: %[[PERM:.+]] = arith.constant dense<[1, 0]> : tensor<2xi64>
+  // CHECK: %[[TRANSPOSED:.+]] = tosa.transpose %arg1, %[[PERM]] : (tensor<6x3xf32>, tensor<2xi64>) -> tensor<3x6xf32>
+  // CHECK: %[[INIT:.+]] = tensor.empty() : tensor<5x6xf32>
+  // CHECK: %[[BROADCASTED:.+]] = linalg.broadcast ins(%arg2 : tensor<6xf32>) outs(%[[INIT]] : tensor<5x6xf32>) dimensions = [0]
+  // CHECK: linalg.matmul ins(%arg0, %[[TRANSPOSED]] : tensor<5x3xf32>, tensor<3x6xf32>) outs(%[[BROADCASTED]] : tensor<5x6xf32>) -> tensor<5x6xf32>
   %0 = tosa.fully_connected %arg0, %arg1, %arg2 : (tensor<5x3xf32>, tensor<6x3xf32>, tensor<6xf32>) -> tensor<5x6xf32>
   return %0 : tensor<5x6xf32>
@@ -91,48 +82,31 @@ func.func @fully_connected(%arg0: tensor<5x3xf32>, %arg1: tensor<6x3xf32>, %arg2
 // -----
-// CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1) -> (d1)>
-// CHECK: #[[$MAP2:.*]] = affine_map<(d0, d1) -> (d0, d1)>
 // CHECK-LABEL: @quantized_fully_connected
 func.func @quantized_fully_connected(%arg0: tensor<5x3xi8>, %arg1: tensor<6x3xi8>, %arg2: tensor<6xi32>) -> (tensor<5x6xi32>) {
-  // CHECK: [[INITT:%.+]] = tensor.empty()
-  // CHECK: [[ZERO:%.+]] = arith.constant 0
-  // CHECK: [[FILL:%.+]] = linalg.fill ins([[ZERO]]{{.*}}outs([[INITT]]
-  // CHECK: [[PERM:%.+]] = arith.constant dense<[1, 0]>
-  // CHECK: [[TRANSPOSE:%.+]] = tosa.transpose %arg1, [[PERM]]
-  // CHECK: [[INITB:%.+]] = tensor.empty()
-  // CHECK: [[ONE:%.+]] = arith.constant 1
-  // CHECK: [[TWO:%.+]] = arith.constant 2
-  // CHECK: [[MATMUL:%.+]] = linalg.quantized_matmul ins(%arg0, [[TRANSPOSE]], [[ONE]], [[TWO]] : tensor<5x3xi8>, tensor<3x6xi8>, i32, i32) outs([[FILL]] : tensor<5x6xi32>) -> tensor<5x6xi32>
-  // CHECK: [[ADDED:%.+]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP2]], #[[$MAP2]]], iterator_types = ["parallel", "parallel"]} ins(%arg2, [[MATMUL]] : tensor<6xi32>, tensor<5x6xi32>) outs([[INITB]]
-  // CHECK: ^bb0([[IN1:%.+]]: i32, [[IN2:%.+]]: i32, [[UNUSED:%.+]]: i32):
-  // CHECK:   [[ADD:%.+]] = arith.addi
-  // CHECK:   linalg.yield [[ADD]] : i32
+  // CHECK: %[[PERM:.+]] = arith.constant dense<[1, 0]> : tensor<2xi64>
+  // CHECK: %[[TRANSPOSE:.+]] = tosa.transpose %arg1, %[[PERM]] : (tensor<6x3xi8>, tensor<2xi64>) -> tensor<3x6xi8>
+  // CHECK: %[[INIT:.+]] = tensor.empty() : tensor<5x6xi32>
+  // CHECK: %[[BROADCASTED:.+]] = linalg.broadcast ins(%arg2 : tensor<6xi32>) outs(%[[INIT]] : tensor<5x6xi32>) dimensions = [0]
+  // CHECK: %[[C1:.+]] = arith.constant 1 : i32
+  // CHECK: %[[C2:.+]] = arith.constant 2 : i32
+  // CHECK: linalg.quantized_matmul ins(%arg0, %[[TRANSPOSE]], %[[C1]], %[[C2]] : tensor<5x3xi8>, tensor<3x6xi8>, i32, i32) outs(%[[BROADCASTED]] : tensor<5x6xi32>) -> tensor<5x6xi32>
   %0 = tosa.fully_connected %arg0, %arg1, %arg2 {quantization_info = #tosa.conv_quant<input_zp = 1, weight_zp = 2>} : (tensor<5x3xi8>, tensor<6x3xi8>, tensor<6xi32>) -> tensor<5x6xi32>
   return %0 : tensor<5x6xi32>
 // -----
-// CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1) -> (d1)>
-// CHECK: #[[$MAP2:.*]] = affine_map<(d0, d1) -> (d0, d1)>
 // CHECK-LABEL: @fully_connected_dyn
 func.func @fully_connected_dyn(%arg0: tensor<?x3xf32>, %arg1: tensor<6x3xf32>, %arg2: tensor<6xf32>) -> (tensor<?x6xf32>) {
-  // CHECK: %[[C0:.+]] = arith.constant 0
-  // CHECK: %[[DIM:.+]] = tensor.dim %arg0, %[[C0]]
-  // CHECK: %[[INITT:.+]] = tensor.empty(%[[DIM]])
-  // CHECK: %[[ZERO:.+]] = arith.constant 0
-  // CHECK: %[[FILL:.+]] = linalg.fill ins(%[[ZERO]]{{.*}}outs(%[[INITT]]
-  // CHECK: %[[PERM:.+]] = arith.constant dense<[1, 0]>
-  // CHECK: %[[TRANSPOSE:.+]] = tosa.transpose %arg1, %[[PERM]]
-  // CHECK: %[[INITB:.+]] = tensor.empty(%[[DIM]])
-  // CHECK: %[[MATMUL:.+]] = linalg.matmul ins(%arg0, %[[TRANSPOSE]] : tensor<?x3xf32>, tensor<3x6xf32>) outs(%[[FILL]] : tensor<?x6xf32>) -> tensor<?x6xf32>
-  // CHECK: %[[ADDED:.+]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP2]], #[[$MAP2]]], iterator_types = ["parallel", "parallel"]} ins(%arg2, %[[MATMUL]] : tensor<6xf32>, tensor<?x6xf32>) outs(%[[INITB]] : tensor<?x6xf32>) {
-  // CHECK: ^bb0(%[[ARG3:[0-9a-zA-Z_]+]]: f32, %[[ARG4:[0-9a-zA-Z_]+]]: f32, %[[ARG5:[0-9a-zA-Z_]+]]: f32):
-  // CHECK:   %[[ADD:.+]] = arith.addf %[[ARG3]], %[[ARG4]] : f32
-  // CHECK:   linalg.yield %[[ADD]] : f32
+  // CHECK: %[[C0:.+]] = arith.constant 0 : index
+  // CHECK: %[[DIM0:.+]] = tensor.dim %arg0, %[[C0]] : tensor<?x3xf32>
+  // CHECK: %[[PERM:.+]] = arith.constant dense<[1, 0]> : tensor<2xi64>
+  // CHECK: %[[TRANSPOSED:.+]] = tosa.transpose %arg1, %[[PERM]] : (tensor<6x3xf32>, tensor<2xi64>) -> tensor<3x6xf32>
+  // CHECK: %[[INIT:.+]] = tensor.empty(%[[DIM0]]) : tensor<?x6xf32>
+  // CHECK: %[[BROADCASTED:.+]] = linalg.broadcast ins(%arg2 : tensor<6xf32>) outs(%[[INIT]] : tensor<?x6xf32>) dimensions = [0]
+  // CHECK: linalg.matmul ins(%arg0, %[[TRANSPOSED]] : tensor<?x3xf32>, tensor<3x6xf32>) outs(%[[BROADCASTED]] : tensor<?x6xf32>) -> tensor<?x6xf32>
   %0 = tosa.fully_connected %arg0, %arg1, %arg2 : (tensor<?x3xf32>, tensor<6x3xf32>, tensor<6xf32>) -> tensor<?x6xf32>
   return %0 : tensor<?x6xf32>

More information about the Mlir-commits mailing list