[Mlir-commits] [mlir] [mlir][tosa] Fix lowering of tosa.conv2d (PR #73240)

Spenser Bauman llvmlistbot at llvm.org
Thu Nov 23 05:11:37 PST 2023


https://github.com/sabauma created https://github.com/llvm/llvm-project/pull/73240

The lowering of tosa.conv2d produces an illegal tensor.empty operation whose number of dynamic size operands does not match the number of dynamic dimensions in the output type.
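For illustration, with a fully static input of type tensor<2x6x5x4xf32> and a result type of tensor<?x4x3x4xf32> (the shapes used in the new test case below), the old code collected no dynamic sizes and emitted roughly:

  %init = tensor.empty() : tensor<?x4x3x4xf32>

which fails verification, since tensor.empty requires one index operand for each dynamic dimension of its result type.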

The fix is to base the generation of tensor.dim operations on the result type of the conv2d operation rather than on the input type. The problem and fix are very similar to those in

https://github.com/llvm/llvm-project/pull/72724

but apply here to convolution.
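With the change, dynamic dimensions are detected on the result type while their sizes are still read from the corresponding input dimensions, so the example above lowers to something like the following (a sketch mirroring the CHECK lines in the new test):

  %c0 = arith.constant 0 : index
  %d0 = tensor.dim %input, %c0 : tensor<2x6x5x4xf32>
  %init = tensor.empty(%d0) : tensor<?x4x3x4xf32>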

From d45ff5b95d0ef4e8a1705088f9bf9c94da642060 Mon Sep 17 00:00:00 2001
From: Spenser Bauman <sbauman at mathworks.com>
Date: Wed, 22 Nov 2023 18:41:52 -0500
Subject: [PATCH] [mlir][tosa] Fix lowering of tosa.conv2d

The lowering of tosa.conv2d produces an illegal tensor.empty operation
whose number of dynamic size operands does not match the number of
dynamic dimensions in the output type.

The fix is to base the generation of tensor.dim operations on the
result type of the conv2d operation rather than on the input type.
The problem and fix are very similar to those in

https://github.com/llvm/llvm-project/pull/72724

but apply here to convolution.
---
 .../TosaToLinalg/TosaToLinalgNamed.cpp        |  4 ++--
 .../TosaToLinalg/tosa-to-linalg-named.mlir    | 23 +++++++++++++++++++
 2 files changed, 25 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
index 9e374be534985e5..328fdac461e3de4 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
@@ -136,7 +136,7 @@ static SmallVector<Value> inferDynamicDimsForConv(
   for (uint32_t i = 0, s = inputSizeDims.size(); i < s; ++i) {
     int64_t inputDim = inputSizeDims[i];
     int64_t kernelDim = kernelSizeDims[i];
-    if (inputTy.isDynamicDim(inputDim)) {
+    if (resultTy.isDynamicDim(inputDim)) {
       auto padTop = padAttr[i * 2];
       auto padBottom = padAttr[i * 2 + 1];
       auto stride = strideAttr[i];
@@ -153,7 +153,7 @@ static SmallVector<Value> inferDynamicDimsForConv(
 
   // Get the batch/channels dimensions.
   for (int i = 0; i < inputRank; i++) {
-    if (inputTy.isDynamicDim(i) && !dynDims[i])
+    if (resultTy.isDynamicDim(i) && !dynDims[i])
       dynDims[i] = rewriter.create<tensor::DimOp>(loc, input, i);
   }
 
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
index 4edc75331932803..6bbaf6dacdb53e0 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
@@ -497,6 +497,29 @@ func.func @conv2d_dyn_w_h(%input: tensor<1x?x?x27xf32>, %weights: tensor<28x3x3x
 
 // -----
 
+// CHECK: [[$MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d3)>
+// CHECK: [[$MAP2:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+
+func.func @conv2d_dyn_output(%input: tensor<2x6x5x4xf32>, %weights: tensor<4x3x3x4xf32>, %bias: tensor<4xf32>) {
+  // %[[C0:.+]] = arith.constant 0 : index
+  // %[[DIM0:.+]] = tensor.dim %input, %[[C0]] : tensor<2x6x5x4xf32>
+  // %[[INIT_CONV:.+]] = tensor.empty(%[[DIM0]]) : tensor<?x4x3x4xf32>
+  // %[[ZERO:.+]] = arith.constant 0.000000e+00 : f32
+  // %[[FILL:.+]] = linalg.fill
+  // %[[INIT_GENERIC:.+]] = tensor.empty([[DIM0]]) : tensor<?x4x3x4xf32>
+
+  // %[[CONV:.+]] = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<2x6x5x4xf32>, tensor<4x3x3x4xf32>) outs(%[[INIT_CONV]] : tensor<?x4x3x4xf32>) -> tensor<?x4x3x4xf32>
+  // linalg.generic {indexing_maps = [#[[MAP1]], #[[MAP2]], #[[MAP2]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2, %[[CONV]] : tensor<4xf32>, tensor<?x4x3x4xf32>) outs(%[[INIT_GENERIC]] : tensor<?x4x3x4xf32>) {
+  //   %[[ADD:.+]] = arith.addf
+  //   linalg.yield %[[ADD]] : f32
+  // } -> tensor<?x4x3x4xf32>
+
+  %0 = tosa.conv2d %input, %weights, %bias {dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<2x6x5x4xf32>, tensor<4x3x3x4xf32>, tensor<4xf32>) -> tensor<?x4x3x4xf32>
+  return
+}
+
+// -----
+
 // CHECK-LABEL: @conv2d_padded_f32
 func.func @conv2d_padded_f32(%input: tensor<1x47x40x28xf32>, %weights: tensor<28x3x3x28xf32>, %bias: tensor<28xf32>) -> () {
   // CHECK: %[[C0:.+]] = arith.constant 0


