[Mlir-commits] [mlir] [mlir] Add Scalar Broadcast TOSA Depthwise Conv (PR #110806)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Oct 2 02:00:51 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Jack Frankland (FranklandJack)

<details>
<summary>Changes</summary>

Support broadcasting of the depthwise conv2d bias in the TOSA-to-Linalg named-op lowering when the bias is a rank-1 tensor with exactly one element. In this case TOSA specifies that the value is first broadcast across the bias dimension and then across the result tensor.

Add `lit` tests for depthwise conv2d with a scalar bias, and for conv3d with a scalar bias, which was already supported but lacked test coverage.

---
Full diff: https://github.com/llvm/llvm-project/pull/110806.diff


2 Files Affected:

- (modified) mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp (+27-22) 
- (modified) mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir (+30) 


``````````diff
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
index fe53b499674324..d537aef5791031 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
@@ -88,15 +88,14 @@ linalgIntBroadcastExtSIAdd(PatternRewriter &rewriter, Location loc, Value bias,
       .getResult(0);
 }
 
-// Broadcast the source value to all the outer dimensions of the result value.
-// If required, the element type is expanded using an arith.extsi operation.
-static mlir::Value linalgBroadcastAndMaybeExtSI(PatternRewriter &rewriter,
-                                                Location loc, Value source,
-                                                Value result) {
+// Construct the affine map that a linalg generic would use to broadcast the
+// source tensor into the shape of the result tensor.
+static AffineMap getBroadcastingMap(PatternRewriter &rewriter, Value source,
+                                    Value result) {
   ShapedType resultTy = cast<ShapedType>(result.getType());
   ShapedType sourceTy = cast<ShapedType>(source.getType());
-  int64_t resultRank = resultTy.getRank();
-  int64_t sourceRank = sourceTy.getRank();
+  const int64_t resultRank = resultTy.getRank();
+  const int64_t sourceRank = sourceTy.getRank();
 
   // The source tensor is broadcast to all the outer dimensions of the
   // result tensor.
@@ -115,14 +114,21 @@ static mlir::Value linalgBroadcastAndMaybeExtSI(PatternRewriter &rewriter,
     }
   }
 
-  // Creating maps for the input and output of the broacast-like generic op.
-  SmallVector<AffineMap, 2> indexingMaps = {
-      // Broadcast the last dimension of the bias to all output dimensions.
-      AffineMap::get(/*dimCount=*/resultRank,
-                     /*symbolCount=*/0, sourceDims, rewriter.getContext()),
+  return AffineMap::get(/*dimCount=*/resultRank,
+                        /*symbolCount=*/0, sourceDims, rewriter.getContext());
+}
 
-      // Output indexing map.
-      rewriter.getMultiDimIdentityMap(resultRank)};
+// Broadcast the source value to all the outer dimensions of the result value.
+// If required, the element type is expanded using an arith.extsi operation.
+static mlir::Value linalgBroadcastAndMaybeExtSI(PatternRewriter &rewriter,
+                                                Location loc, Value source,
+                                                Value result) {
+  ShapedType resultTy = cast<ShapedType>(result.getType());
+  const int64_t resultRank = resultTy.getRank();
+  // Creating maps for the input and output of the broadcast-like generic op.
+  SmallVector<AffineMap, 2> indexingMaps;
+  indexingMaps.push_back(getBroadcastingMap(rewriter, source, result));
+  indexingMaps.push_back(rewriter.getMultiDimIdentityMap(resultRank));
 
   // Build the broadcast-like operation as a linalg.generic.
   return rewriter
@@ -488,14 +494,6 @@ class DepthwiseConvConverter
                                weightShape[2], weightShape[3]},
                               resultETy);
 
-    // Broadcast the initial value to the output tensor before convolving.
-    SmallVector<AffineMap, 4> indexingMaps;
-    indexingMaps.push_back(AffineMap::get(
-        /*dimCount=*/resultRank, /*symbolCount=*/0,
-        {rewriter.getAffineDimExpr(3)}, rewriter.getContext()));
-    indexingMaps.push_back(rewriter.getMultiDimIdentityMap(resultRank));
-    indexingMaps.push_back(rewriter.getMultiDimIdentityMap(resultRank));
-
     auto resultZeroAttr = rewriter.getZeroAttr(resultETy);
     Value emptyTensor = rewriter.create<tensor::EmptyOp>(
         loc, linalgConvTy.getShape(), resultETy, filteredDims);
@@ -507,6 +505,13 @@ class DepthwiseConvConverter
 
     Value biasEmptyTensor = rewriter.create<tensor::EmptyOp>(
         loc, resultTy.getShape(), resultETy, filteredDims);
+
+    // Broadcast the initial value to the output tensor before convolving.
+    SmallVector<AffineMap, 4> indexingMaps;
+    indexingMaps.push_back(getBroadcastingMap(rewriter, bias, biasEmptyTensor));
+    indexingMaps.push_back(rewriter.getMultiDimIdentityMap(resultRank));
+    indexingMaps.push_back(rewriter.getMultiDimIdentityMap(resultRank));
+
     if (!isQuantized) {
       Value conv = rewriter
                        .create<linalg::DepthwiseConv2DNhwcHwcmOp>(
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
index 0d55d1899c713e..bfdc72ee07e97f 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
@@ -702,6 +702,22 @@ func.func @depthwise_conv(%arg0 : tensor<1x7x5x3xf32>, %arg1 : tensor<3x1x3x11xf
 
 // -----
 
+// CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (0)>
+// CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+
+// CHECK-LABEL: @depthwise_conv_scalar_bias
+func.func @depthwise_conv_scalar_bias(%arg0 : tensor<1x7x5x3xf32>, %arg1 : tensor<3x1x3x11xf32>, %arg2 : tensor<1xf32>) -> () {
+  // CHECK: [[BIAS:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2, %{{.*}} : tensor<1xf32>, tensor<1x5x5x33xf32>) outs(%{{.*}} : tensor<1x5x5x33xf32>) {
+  // CHECK: ^bb0(%[[ARG3:[0-9a-zA-Z_]+]]: f32, %[[ARG4:[0-9a-zA-Z_]+]]: f32, %{{.*}}: f32):
+  // CHECK:   [[ADD:%.+]] = arith.addf %[[ARG3]], %[[ARG4]] : f32
+  // CHECK:   linalg.yield [[ADD]] : f32
+  // CHECK: } -> tensor<1x5x5x33xf32>
+  %2 = tosa.depthwise_conv2d %arg0, %arg1, %arg2 { pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, dilation = array<i64: 1, 1> } : (tensor<1x7x5x3xf32>, tensor<3x1x3x11xf32>, tensor<1xf32>)  -> tensor<1x5x5x33xf32>
+  return
+}
+
+// -----
+
 // CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d3)>
 // CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
 
@@ -840,6 +856,20 @@ func.func @conv3d_f32(%input: tensor<1x49x48x47x27xf32>, %weights: tensor<28x3x4
 
 // -----
 
+// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (0)>
+// CHECK: #[[$MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
+
+// CHECK-LABEL: @conv3d_scalar_bias_f32
+func.func @conv3d_scalar_bias_f32(%input: tensor<1x49x48x47x27xf32>, %weights: tensor<28x3x4x5x27xf32>, %bias: tensor<1xf32>) -> () {
+  // CHECK:  %[[INIT:.+]] = tensor.empty() : tensor<1x47x45x43x28xf32>
+  // CHECK:      %[[BROADCAST:.+]] = linalg.generic
+  // CHECK-SAME: {indexing_maps = [#[[$MAP1]], #[[$MAP2]]], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]}
+  %0 = tosa.conv3d %input, %weights, %bias {pad = array<i64: 0, 0, 0, 0, 0, 0>, stride = array<i64: 1, 1, 1>, dilation = array<i64: 1, 1, 1>} : (tensor<1x49x48x47x27xf32>, tensor<28x3x4x5x27xf32>, tensor<1xf32>)  -> tensor<1x47x45x43x28xf32>
+  return
+}
+
+// -----
+
 // CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d4)>
 // CHECK: #[[$MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
 

``````````

</details>


https://github.com/llvm/llvm-project/pull/110806


More information about the Mlir-commits mailing list