[Mlir-commits] [mlir] f58fb8c - [mlir][tosa] Fix lowering of tosa.conv2d (#73240)
llvmlistbot at llvm.org
Fri Dec 1 07:33:19 PST 2023
Author: Spenser Bauman
Date: 2023-12-01T15:33:14Z
New Revision: f58fb8c209a5179f8f2e02e2a0816c9b1f1edb1b
URL: https://github.com/llvm/llvm-project/commit/f58fb8c209a5179f8f2e02e2a0816c9b1f1edb1b
DIFF: https://github.com/llvm/llvm-project/commit/f58fb8c209a5179f8f2e02e2a0816c9b1f1edb1b.diff
LOG: [mlir][tosa] Fix lowering of tosa.conv2d (#73240)
The lowering of tosa.conv2d produces an illegal tensor.empty operation
whose number of operands does not match the number of dynamic dimensions
in the output type.
The fix is to base the generation of tensor.dim operations on the
result type of the conv2d operation rather than on the input type. The
problem and fix are very similar to
https://github.com/llvm/llvm-project/pull/72724
but apply to convolution.
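A minimal MLIR sketch of the failure mode (not taken from the commit; %input is a hypothetical value with a fully static type, while the conv2d result type has a dynamic batch dimension). tensor.empty must receive one index operand per dynamic dimension of its result type, so keying the dynamic-dimension check off the input type produced too few operands:

  // Invalid: the result type has one dynamic dimension, but no size operand is supplied.
  %bad = tensor.empty() : tensor<?x4x3x4xf32>

  // Valid: query the batch size and pass it as the dynamic size, which is what the
  // lowering does once it keys off the result type instead of the input type.
  %c0 = arith.constant 0 : index
  %d0 = tensor.dim %input, %c0 : tensor<2x6x5x4xf32>
  %ok = tensor.empty(%d0) : tensor<?x4x3x4xf32>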
Added:
Modified:
mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
index b30651976eeb939..0accd9d1986a1ed 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
@@ -179,7 +179,7 @@ static SmallVector<Value> inferDynamicDimsForConv(
for (uint32_t i = 0, s = inputSizeDims.size(); i < s; ++i) {
int64_t inputDim = inputSizeDims[i];
int64_t kernelDim = kernelSizeDims[i];
- if (inputTy.isDynamicDim(inputDim)) {
+ if (resultTy.isDynamicDim(inputDim)) {
auto padTop = padAttr[i * 2];
auto padBottom = padAttr[i * 2 + 1];
auto stride = strideAttr[i];
@@ -196,7 +196,7 @@ static SmallVector<Value> inferDynamicDimsForConv(
// Get the batch/channels dimensions.
for (int i = 0; i < inputRank; i++) {
- if (inputTy.isDynamicDim(i) && !dynDims[i])
+ if (resultTy.isDynamicDim(i) && !dynDims[i])
dynDims[i] = rewriter.create<tensor::DimOp>(loc, input, i);
}
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
index bbdd1bad799865d..230001f7633b570 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
@@ -495,6 +495,29 @@ func.func @conv2d_dyn_w_h(%input: tensor<1x?x?x27xf32>, %weights: tensor<28x3x3x
// -----
+// CHECK: [[$MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d3)>
+// CHECK: [[$MAP2:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+
+func.func @conv2d_dyn_output(%input: tensor<2x6x5x4xf32>, %weights: tensor<4x3x3x4xf32>, %bias: tensor<4xf32>) {
+ // %[[C0:.+]] = arith.constant 0 : index
+ // %[[DIM0:.+]] = tensor.dim %input, %[[C0]] : tensor<2x6x5x4xf32>
+ // %[[INIT_CONV:.+]] = tensor.empty(%[[DIM0]]) : tensor<?x4x3x4xf32>
+ // %[[ZERO:.+]] = arith.constant 0.000000e+00 : f32
+ // %[[FILL:.+]] = linalg.fill
+ // %[[INIT_GENERIC:.+]] = tensor.empty([[DIM0]]) : tensor<?x4x3x4xf32>
+
+ // %[[CONV:.+]] = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<2x6x5x4xf32>, tensor<4x3x3x4xf32>) outs(%[[INIT_CONV]] : tensor<?x4x3x4xf32>) -> tensor<?x4x3x4xf32>
+ // linalg.generic {indexing_maps = [#[[MAP1]], #[[MAP2]], #[[MAP2]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2, %[[CONV]] : tensor<4xf32>, tensor<?x4x3x4xf32>) outs(%[[INIT_GENERIC]] : tensor<?x4x3x4xf32>) {
+ // %[[ADD:.+]] = arith.addf
+ // linalg.yield %[[ADD]] : f32
+ // } -> tensor<?x4x3x4xf32>
+
+ %0 = tosa.conv2d %input, %weights, %bias {dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<2x6x5x4xf32>, tensor<4x3x3x4xf32>, tensor<4xf32>) -> tensor<?x4x3x4xf32>
+ return
+}
+
+// -----
+
// CHECK-LABEL: @conv2d_padded_f32
func.func @conv2d_padded_f32(%input: tensor<1x47x40x28xf32>, %weights: tensor<28x3x3x28xf32>, %bias: tensor<28xf32>) -> () {
// CHECK: %[[C0:.+]] = arith.constant 0