[Mlir-commits] [mlir] [mlir][tosa] Fix lowering of tosa.conv2d (PR #73240)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Nov 23 05:12:06 PST 2023


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-linalg

Author: Spenser Bauman (sabauma)

<details>
<summary>Changes</summary>

The lowering of tosa.conv2d produces an illegal tensor.empty operation in which the number of operands does not match the number of dynamic dimensions in the output type.

The fix is to base the generation of tensor.dim operations on the result type of the conv2d operation rather than on the input type. The problem and the fix are very similar to those in this earlier change:

https://github.com/llvm/llvm-project/pull/72724

but for convolution.

---
Full diff: https://github.com/llvm/llvm-project/pull/73240.diff


2 Files Affected:

- (modified) mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp (+2-2) 
- (modified) mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir (+23) 


``````````diff
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
index 9e374be534985e5..328fdac461e3de4 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
@@ -136,7 +136,7 @@ static SmallVector<Value> inferDynamicDimsForConv(
   for (uint32_t i = 0, s = inputSizeDims.size(); i < s; ++i) {
     int64_t inputDim = inputSizeDims[i];
     int64_t kernelDim = kernelSizeDims[i];
-    if (inputTy.isDynamicDim(inputDim)) {
+    if (resultTy.isDynamicDim(inputDim)) {
       auto padTop = padAttr[i * 2];
       auto padBottom = padAttr[i * 2 + 1];
       auto stride = strideAttr[i];
@@ -153,7 +153,7 @@ static SmallVector<Value> inferDynamicDimsForConv(
 
   // Get the batch/channels dimensions.
   for (int i = 0; i < inputRank; i++) {
-    if (inputTy.isDynamicDim(i) && !dynDims[i])
+    if (resultTy.isDynamicDim(i) && !dynDims[i])
       dynDims[i] = rewriter.create<tensor::DimOp>(loc, input, i);
   }
 
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
index 4edc75331932803..6bbaf6dacdb53e0 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
@@ -497,6 +497,29 @@ func.func @conv2d_dyn_w_h(%input: tensor<1x?x?x27xf32>, %weights: tensor<28x3x3x
 
 // -----
 
+// CHECK: [[$MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d3)>
+// CHECK: [[$MAP2:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+
+func.func @conv2d_dyn_output(%input: tensor<2x6x5x4xf32>, %weights: tensor<4x3x3x4xf32>, %bias: tensor<4xf32>) {
+  // %[[C0:.+]] = arith.constant 0 : index
+  // %[[DIM0:.+]] = tensor.dim %input, %[[C0]] : tensor<2x6x5x4xf32>
+  // %[[INIT_CONV:.+]] = tensor.empty(%[[DIM0]]) : tensor<?x4x3x4xf32>
+  // %[[ZERO:.+]] = arith.constant 0.000000e+00 : f32
+  // %[[FILL:.+]] = linalg.fill
+  // %[[INIT_GENERIC:.+]] = tensor.empty([[DIM0]]) : tensor<?x4x3x4xf32>
+
+  // %[[CONV:.+]] = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<2x6x5x4xf32>, tensor<4x3x3x4xf32>) outs(%[[INIT_CONV]] : tensor<?x4x3x4xf32>) -> tensor<?x4x3x4xf32>
+  // linalg.generic {indexing_maps = [#[[MAP1]], #[[MAP2]], #[[MAP2]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2, %[[CONV]] : tensor<4xf32>, tensor<?x4x3x4xf32>) outs(%[[INIT_GENERIC]] : tensor<?x4x3x4xf32>) {
+  //   %[[ADD:.+]] = arith.addf
+  //   linalg.yield %[[ADD]] : f32
+  // } -> tensor<?x4x3x4xf32>
+
+  %0 = tosa.conv2d %input, %weights, %bias {dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<2x6x5x4xf32    >, tensor<4x3x3x4xf32>, tensor<4xf32>) -> tensor<?x4x3x4xf32>
+  return
+}
+
+// -----
+
 // CHECK-LABEL: @conv2d_padded_f32
 func.func @conv2d_padded_f32(%input: tensor<1x47x40x28xf32>, %weights: tensor<28x3x3x28xf32>, %bias: tensor<28xf32>) -> () {
   // CHECK: %[[C0:.+]] = arith.constant 0

``````````

</details>


https://github.com/llvm/llvm-project/pull/73240


More information about the Mlir-commits mailing list