[Mlir-commits] [mlir] [mlir][tosa] Fix lowering of tosa.conv2d (PR #73240)

Spenser Bauman llvmlistbot at llvm.org
Fri Dec 1 07:12:49 PST 2023


https://github.com/sabauma updated https://github.com/llvm/llvm-project/pull/73240

>From e141181c3e06d0941bdac5dfcae0c95a72b7670b Mon Sep 17 00:00:00 2001
From: Spenser Bauman <sbauman at mathworks.com>
Date: Wed, 22 Nov 2023 18:41:52 -0500
Subject: [PATCH] [mlir][tosa] Fix lowering of tosa.conv2d

The lowering of tosa.conv2d produces an illegal tensor.empty operation
where the number of inputs does not match the number of dynamic
dimensions in the output type.

The fix is to base the generation of tensor.dim operations on the
result type of the conv2d operation, rather than on the input type.
The problem and its fix closely mirror the following change

https://github.com/llvm/llvm-project/pull/72724

but for convolution.
---
 .../TosaToLinalg/TosaToLinalgNamed.cpp        |  4 ++--
 .../TosaToLinalg/tosa-to-linalg-named.mlir    | 23 +++++++++++++++++++
 2 files changed, 25 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
index 9e374be534985e5..328fdac461e3de4 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
@@ -136,7 +136,7 @@ static SmallVector<Value> inferDynamicDimsForConv(
   for (uint32_t i = 0, s = inputSizeDims.size(); i < s; ++i) {
     int64_t inputDim = inputSizeDims[i];
     int64_t kernelDim = kernelSizeDims[i];
-    if (inputTy.isDynamicDim(inputDim)) {
+    if (resultTy.isDynamicDim(inputDim)) {
       auto padTop = padAttr[i * 2];
       auto padBottom = padAttr[i * 2 + 1];
       auto stride = strideAttr[i];
@@ -153,7 +153,7 @@ static SmallVector<Value> inferDynamicDimsForConv(
 
   // Get the batch/channels dimensions.
   for (int i = 0; i < inputRank; i++) {
-    if (inputTy.isDynamicDim(i) && !dynDims[i])
+    if (resultTy.isDynamicDim(i) && !dynDims[i])
       dynDims[i] = rewriter.create<tensor::DimOp>(loc, input, i);
   }
 
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
index 4edc75331932803..6bbaf6dacdb53e0 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
@@ -497,6 +497,29 @@ func.func @conv2d_dyn_w_h(%input: tensor<1x?x?x27xf32>, %weights: tensor<28x3x3x
 
 // -----
 
+// CHECK: [[$MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d3)>
+// CHECK: [[$MAP2:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+
+func.func @conv2d_dyn_output(%input: tensor<2x6x5x4xf32>, %weights: tensor<4x3x3x4xf32>, %bias: tensor<4xf32>) {
+  // %[[C0:.+]] = arith.constant 0 : index
+  // %[[DIM0:.+]] = tensor.dim %input, %[[C0]] : tensor<2x6x5x4xf32>
+  // %[[INIT_CONV:.+]] = tensor.empty(%[[DIM0]]) : tensor<?x4x3x4xf32>
+  // %[[ZERO:.+]] = arith.constant 0.000000e+00 : f32
+  // %[[FILL:.+]] = linalg.fill
+  // %[[INIT_GENERIC:.+]] = tensor.empty(%[[DIM0]]) : tensor<?x4x3x4xf32>
+
+  // %[[CONV:.+]] = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<2x6x5x4xf32>, tensor<4x3x3x4xf32>) outs(%[[INIT_CONV]] : tensor<?x4x3x4xf32>) -> tensor<?x4x3x4xf32>
+  // linalg.generic {indexing_maps = [#[[MAP1]], #[[MAP2]], #[[MAP2]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2, %[[CONV]] : tensor<4xf32>, tensor<?x4x3x4xf32>) outs(%[[INIT_GENERIC]] : tensor<?x4x3x4xf32>) {
+  //   %[[ADD:.+]] = arith.addf
+  //   linalg.yield %[[ADD]] : f32
+  // } -> tensor<?x4x3x4xf32>
+
+  %0 = tosa.conv2d %input, %weights, %bias {dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<2x6x5x4xf32    >, tensor<4x3x3x4xf32>, tensor<4xf32>) -> tensor<?x4x3x4xf32>
+  return
+}
+
+// -----
+
 // CHECK-LABEL: @conv2d_padded_f32
 func.func @conv2d_padded_f32(%input: tensor<1x47x40x28xf32>, %weights: tensor<28x3x3x28xf32>, %bias: tensor<28xf32>) -> () {
   // CHECK: %[[C0:.+]] = arith.constant 0



More information about the Mlir-commits mailing list