[Mlir-commits] [mlir] [mlir][tosa] Fix lowering of tosa.conv2d (PR #73240)
llvmlistbot at llvm.org
Thu Nov 23 05:12:06 PST 2023
llvmbot wrote:
@llvm/pr-subscribers-mlir-linalg
Author: Spenser Bauman (sabauma)
<details>
<summary>Changes</summary>
The lowering of tosa.conv2d produces an illegal tensor.empty operation in which the number of inputs does not match the number of dynamic dimensions in the output type.
The fix bases the generation of tensor.dim operations on the result type of the conv2d operation rather than on the input type. Both the problem and the fix closely mirror
https://github.com/llvm/llvm-project/pull/72724
but apply to convolution.
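
The mismatch can be illustrated with a small standalone model that does not use MLIR. This is only a sketch under stated assumptions: `Shape`, `kDynamicExtent`, `dimsNeedingSizeValues`, and `loweringIsLegal` are hypothetical names introduced here, and the model keeps only the one rule described above, namely that tensor.empty needs exactly one size operand per dynamic dimension of its result type.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical sentinel for a dynamic extent; it stands in for MLIR's
// dynamic-size marker and is not the real constant.
constexpr std::int64_t kDynamicExtent = -1;

using Shape = std::vector<std::int64_t>;

// Indices of the dimensions that `decisionShape` marks as dynamic. In this
// toy model, one size value is materialized for each returned index.
std::vector<std::size_t> dimsNeedingSizeValues(const Shape &decisionShape) {
  std::vector<std::size_t> dims;
  for (std::size_t i = 0; i < decisionShape.size(); ++i)
    if (decisionShape[i] == kDynamicExtent)
      dims.push_back(i);
  return dims;
}

// Returns true when the number of materialized size values equals the number
// of dynamic dimensions in the result shape. `useResultType` selects which
// shape drives the decision: false models the old behavior (input type),
// true models the fixed behavior (result type).
bool loweringIsLegal(const Shape &input, const Shape &result,
                     bool useResultType) {
  assert(input.size() == result.size() && "input and result ranks must match");
  const std::size_t numSizeOperands =
      dimsNeedingSizeValues(useResultType ? result : input).size();
  const std::size_t numDynamicResultDims = dimsNeedingSizeValues(result).size();
  return numSizeOperands == numDynamicResultDims;
}
```

With the shapes from the new test, a fully static input `{2, 6, 5, 4}` and a result `{kDynamicExtent, 4, 3, 4}`, deciding from the input shape yields no size values for a result with one dynamic dimension, while deciding from the result shape yields exactly one, for dimension 0.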
---
Full diff: https://github.com/llvm/llvm-project/pull/73240.diff
2 Files Affected:
- (modified) mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp (+2-2)
- (modified) mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir (+23)
``````````diff
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
index 9e374be534985e5..328fdac461e3de4 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
@@ -136,7 +136,7 @@ static SmallVector<Value> inferDynamicDimsForConv(
for (uint32_t i = 0, s = inputSizeDims.size(); i < s; ++i) {
int64_t inputDim = inputSizeDims[i];
int64_t kernelDim = kernelSizeDims[i];
- if (inputTy.isDynamicDim(inputDim)) {
+ if (resultTy.isDynamicDim(inputDim)) {
auto padTop = padAttr[i * 2];
auto padBottom = padAttr[i * 2 + 1];
auto stride = strideAttr[i];
@@ -153,7 +153,7 @@ static SmallVector<Value> inferDynamicDimsForConv(
// Get the batch/channels dimensions.
for (int i = 0; i < inputRank; i++) {
- if (inputTy.isDynamicDim(i) && !dynDims[i])
+ if (resultTy.isDynamicDim(i) && !dynDims[i])
dynDims[i] = rewriter.create<tensor::DimOp>(loc, input, i);
}
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
index 4edc75331932803..6bbaf6dacdb53e0 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
@@ -497,6 +497,29 @@ func.func @conv2d_dyn_w_h(%input: tensor<1x?x?x27xf32>, %weights: tensor<28x3x3x
// -----
+// CHECK: [[$MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d3)>
+// CHECK: [[$MAP2:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+
+func.func @conv2d_dyn_output(%input: tensor<2x6x5x4xf32>, %weights: tensor<4x3x3x4xf32>, %bias: tensor<4xf32>) {
+ // %[[C0:.+]] = arith.constant 0 : index
+ // %[[DIM0:.+]] = tensor.dim %input, %[[C0]] : tensor<2x6x5x4xf32>
+ // %[[INIT_CONV:.+]] = tensor.empty(%[[DIM0]]) : tensor<?x4x3x4xf32>
+ // %[[ZERO:.+]] = arith.constant 0.000000e+00 : f32
+ // %[[FILL:.+]] = linalg.fill
+ // %[[INIT_GENERIC:.+]] = tensor.empty([[DIM0]]) : tensor<?x4x3x4xf32>
+
+ // %[[CONV:.+]] = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<2x6x5x4xf32>, tensor<4x3x3x4xf32>) outs(%[[INIT_CONV]] : tensor<?x4x3x4xf32>) -> tensor<?x4x3x4xf32>
+ // linalg.generic {indexing_maps = [#[[MAP1]], #[[MAP2]], #[[MAP2]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2, %[[CONV]] : tensor<4xf32>, tensor<?x4x3x4xf32>) outs(%[[INIT_GENERIC]] : tensor<?x4x3x4xf32>) {
+ // %[[ADD:.+]] = arith.addf
+ // linalg.yield %[[ADD]] : f32
+ // } -> tensor<?x4x3x4xf32>
+
+ %0 = tosa.conv2d %input, %weights, %bias {dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<2x6x5x4xf32 >, tensor<4x3x3x4xf32>, tensor<4xf32>) -> tensor<?x4x3x4xf32>
+ return
+}
+
+// -----
+
// CHECK-LABEL: @conv2d_padded_f32
func.func @conv2d_padded_f32(%input: tensor<1x47x40x28xf32>, %weights: tensor<28x3x3x28xf32>, %bias: tensor<28xf32>) -> () {
// CHECK: %[[C0:.+]] = arith.constant 0
``````````
</details>
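
As a sanity check on the new test's result type `tensor<?x4x3x4xf32>`, the static height and width extents can be reproduced with the usual convolution output-size formula. The helper below is a hypothetical standalone sketch, not the code in TosaToLinalgNamed.cpp, and it assumes the extents divide evenly by the stride.

```cpp
#include <cassert>
#include <cstdint>

// Output extent of one spatial dimension of a convolution, using the usual
// formula: (in + padBefore + padAfter - dilation * (kernel - 1) - 1) / stride + 1.
// `convOutputExtent` is a hypothetical helper name used for illustration only.
std::int64_t convOutputExtent(std::int64_t inputExtent,
                              std::int64_t kernelExtent,
                              std::int64_t padBefore, std::int64_t padAfter,
                              std::int64_t stride, std::int64_t dilation) {
  assert(stride > 0 && dilation > 0 && kernelExtent > 0);
  // Span covered by the kernel once dilation is taken into account.
  const std::int64_t effectiveKernel = dilation * (kernelExtent - 1) + 1;
  return (inputExtent + padBefore + padAfter - effectiveKernel) / stride + 1;
}
```

For the test's 6x5 input, 3x3 kernel, zero padding, and unit stride and dilation, `convOutputExtent(6, 3, 0, 0, 1, 1)` gives 4 and `convOutputExtent(5, 3, 0, 0, 1, 1)` gives 3, matching the `4x3` spatial extents in the result type.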
https://github.com/llvm/llvm-project/pull/73240