[Mlir-commits] [mlir] [mlir][tosa] Convert tosa.transpose_conv2d to linalg.generic directly (PR #79824)
Hsiangkai Wang
llvmlistbot at llvm.org
Tue Jan 30 15:10:47 PST 2024
Hsiangkai wrote:
> Hi, Hsiangkai. Thank you for your contribution. A first look at the proposed lowering led me to the following questions and suggestions:
>
> * Does this code work for any combination of input arguments and attributes?
> * What would it take to have full support for dynamic dimensions? There is technically no requirement for `tosa.transpose_conv2d` to use static dimensions for weight or bias, which makes the corresponding error message rather misleading. We (at MathWorks) have recently been working on comprehensive lowerings for the `tfl` (Tensorflow Lite) dialect counterparts of 2D/3D transpose convs. We're aware it gets significantly trickier to support dynamic dims, but it is possible to emit generic code that is also efficient when only static dimensions are used (by relying on op folders).
> * Has the correctness of this lowering been validated with any sort of runtime framework? We can help with that, but it would be desirable to first make sure that this lowering is complete according to the TOSA op spec.
> * The test coverage looks minimal. An extensive set of unit tests should verify that the lowering strategy works correctly for a reasonable set of input argument/attribute combinations. But again, since FileCheck tests are very sensitive to the exact sequence of emitted ops, it'd be desirable to attain full spec support before addressing this.
> * A thorough verification of valid combinations of input arguments can save end users lots of headaches (e.g., verify that `out_shape` is consistent with input/weight/stride/pad values when known, verify that `out_shape` matches the output shape when static, verify that the number of input channels matches in input/weights when known, etc.). Failing to verify these will likely lead to obscure error messages when the `linalg.generic` op is emitted. Maybe some of these verifications belong in the op verifier instead. My personal preference here is including most lightweight verifications in the op verifier, and then adding `assert`s in the lowering with a comment indicating that the op verifier enforces a certain invariant.
> * Is this intended to act as a replacement for the `TosaDecomposeTransposeConv` pass? If so, is there still a point in keeping that pass?
> * It is not clear to me that it is preferable to lower directly to `linalg.generic` versus the decomposition approach based on pads + reverses + conv2d. Our approach when working on TFL lowerings was the latter, with the benefit that the resulting `tosa.conv2d`/`linalg.conv2d_xxx` ops may be mapped to device-specific intrinsics for better performance. Maybe allowing the user to optionally run the decomposition pass ahead of time would be a good tradeoff here.
Hi Rafael,
Thanks for your review and comments. I removed the static-shape constraints on the input and weight, so the pattern now works with dynamic shapes for the input, weight, and bias. However, the op has an `out_shape` attribute that specifies the result shape, so I modified the patch to require a static result shape instead.
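As a rough illustration of how `out_shape` relates to the other operands, here is a Python sketch of the shape checks suggested in the review (consistency of `out_shape` with input/weight/stride/pad, matching channel counts). It assumes the TOSA specification's formula `OH = (IH - 1) * stride_y + out_pad_top + out_pad_bottom + KH` (and likewise for the width), an NHWC input, `[OC, KH, KW, IC]` weights, and a bias of length OC. The function names are hypothetical and not part of MLIR. A `None` dimension models a dynamic size, and any check that involves one is skipped.

```python
from typing import List, Optional, Sequence

# None models a dynamic ('?') dimension whose size is unknown at compile time.
Dim = Optional[int]


def expected_out_dim(in_dim: Dim, kernel: Dim, stride: int,
                     pad_before: int, pad_after: int) -> Dim:
    """Spec formula O = (I - 1) * stride + pad_before + pad_after + K."""
    if in_dim is None or kernel is None:
        return None
    return (in_dim - 1) * stride + pad_before + pad_after + kernel


def _differ(a: Dim, b: Dim) -> bool:
    """True only when both sizes are known and disagree."""
    return a is not None and b is not None and a != b


def verify_transpose_conv2d(
    input_shape: Sequence[Dim],   # [N, IH, IW, IC]
    weight_shape: Sequence[Dim],  # [OC, KH, KW, IC]
    bias_shape: Sequence[Dim],    # [OC] (assumed; no broadcast bias)
    out_shape: Sequence[Dim],     # [N, OH, OW, OC]
    stride: Sequence[int],        # [stride_y, stride_x]
    out_pad: Sequence[int],       # [top, bottom, left, right]
) -> List[str]:
    """Return the violated invariants; an empty list means none were detected."""
    n, ih, iw, ic = input_shape
    oc, kh, kw, w_ic = weight_shape
    (bc,) = bias_shape
    on, oh, ow, o_oc = out_shape
    stride_y, stride_x = stride
    top, bottom, left, right = out_pad
    errors: List[str] = []

    if stride_y < 1 or stride_x < 1:
        errors.append("stride values must be >= 1")
    exp_oh = expected_out_dim(ih, kh, stride_y, top, bottom)
    exp_ow = expected_out_dim(iw, kw, stride_x, left, right)
    if _differ(oh, exp_oh):
        errors.append(f"output height {oh} != expected {exp_oh}")
    if _differ(ow, exp_ow):
        errors.append(f"output width {ow} != expected {exp_ow}")
    if _differ(on, n):
        errors.append(f"output batch {on} != input batch {n}")
    if _differ(o_oc, oc):
        errors.append(f"output channels {o_oc} != weight output channels {oc}")
    if _differ(ic, w_ic):
        errors.append(f"input channels {ic} != weight input channels {w_ic}")
    if _differ(bc, oc):
        errors.append(f"bias length {bc} != weight output channels {oc}")
    return errors
```

For example, a `[1, 2, 3, 4]` input with `[5, 3, 3, 4]` weights, stride `[2, 2]` and zero padding gives an expected spatial output of 5x7, so `verify_transpose_conv2d([1, 2, 3, 4], [5, 3, 3, 4], [5], [1, 5, 7, 5], [2, 2], [0, 0, 0, 0])` returns an empty list. In an op verifier, the same relations would be checked on the known dimensions only.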
* I added more test cases, including dynamic shapes with padding and strides. The code generated after `convert-linalg-to-parallel-loops` has the same semantics as TransposeConv2D in the TOSA specification. For example:
```mlir
func.func @global_tosa_transpose_conv2d_kernel(%arg0: memref<?x?x?x?xf32>, %arg1: memref<?x?x?x?xf32>, %arg2: memref<?xf32>, %arg3: memref<1x1x3x1xf32>) {
  %c2 = arith.constant 2 : index
  %c3 = arith.constant 3 : index
  %c1 = arith.constant 1 : index
  %c0 = arith.constant 0 : index
  // Initialize the output (%arg3) with the bias (%arg2), broadcast over N, H and W.
  %dim = memref.dim %arg2, %c0 : memref<?xf32>
  scf.parallel (%arg4, %arg5, %arg6, %arg7) = (%c0, %c0, %c0, %c0) to (%c1, %c1, %c3, %dim) step (%c1, %c1, %c1, %c1) {
    %0 = memref.load %arg2[%arg7] : memref<?xf32>
    memref.store %0, %arg3[%arg4, %arg5, %arg6, %arg7] : memref<1x1x3x1xf32>
    scf.reduce
  }
  // Input (%arg0) is [N, IH, IW, IC]; weights (%arg1) are [OC, KH, KW, IC].
  %dim_0 = memref.dim %arg0, %c0 : memref<?x?x?x?xf32>
  %dim_1 = memref.dim %arg0, %c1 : memref<?x?x?x?xf32>
  %dim_2 = memref.dim %arg0, %c2 : memref<?x?x?x?xf32>
  %dim_3 = memref.dim %arg0, %c3 : memref<?x?x?x?xf32>
  %dim_4 = memref.dim %arg1, %c0 : memref<?x?x?x?xf32>
  %dim_5 = memref.dim %arg1, %c1 : memref<?x?x?x?xf32>
  %dim_6 = memref.dim %arg1, %c2 : memref<?x?x?x?xf32>
  // Loop indices %arg4..%arg10 are (n, iy, ix, ky, kx, oc, ic).
  scf.parallel (%arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10) = (%c0, %c0, %c0, %c0, %c0, %c0, %c0) to (%dim_0, %dim_1, %dim_2, %dim_5, %dim_6, %dim_4, %dim_3) step (%c1, %c1, %c1, %c1, %c1, %c1, %c1) {
    %0 = memref.load %arg0[%arg4, %arg5, %arg6, %arg10] : memref<?x?x?x?xf32>
    %1 = memref.load %arg1[%arg9, %arg7, %arg8, %arg10] : memref<?x?x?x?xf32>
    // Output coordinates: oy = iy * 5 + ky + 1, ox = ix * 6 + kx + 3.
    %2 = affine.apply affine_map<(d0, d1) -> (d0 * 5 + d1 + 1)>(%arg5, %arg7)
    %3 = affine.apply affine_map<(d0, d1) -> (d0 * 6 + d1 + 3)>(%arg6, %arg8)
    %4 = memref.load %arg3[%arg4, %2, %3, %arg9] : memref<1x1x3x1xf32>
    %5 = affine.apply affine_map<(d0, d1) -> (d0 * 5 + d1 + 1)>(%arg5, %arg7)
    %6 = affine.apply affine_map<(d0, d1) -> (d0 * 6 + d1 + 3)>(%arg6, %arg8)
    // Accumulate input * weight into output[n, oy, ox, oc].
    %7 = arith.mulf %0, %1 : f32
    %8 = arith.addf %7, %4 : f32
    memref.store %8, %arg3[%arg4, %5, %6, %arg9] : memref<1x1x3x1xf32>
    scf.reduce
  }
  return
}
```
* No, I have not tried it on any runtime framework. I appreciate your willingness to help.
* I will add more verification later, as asserts in the lowering.
* No, I do not intend to replace `TosaDecomposeTransposeConv`; I just want to give users another codegen option. I agree that letting users optionally run the decomposition pass first would be good. I will think about how to change the patch so that this pattern is not used by default. Do you have any suggestions on how to do that?
* As for performance, I don't know which approach is better; I am still new to MLIR.
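To make the semantic comparison easier to check by hand, the following is a float-only Python sketch of the TOSA TRANSPOSE_CONV2D reference pseudocode (zero points omitted, nested lists in place of tensors, hypothetical function names). Its two loop nests mirror the two `scf.parallel` loops above: the first broadcasts the bias into the output, and the second accumulates `input * weight` at `oy = iy * stride_y + out_pad_top + ky`, `ox = ix * stride_x + out_pad_left + kx`. The spec only accumulates when `(oy, ox)` falls inside the output, which matters when `out_pad` values are negative.

```python
from itertools import product
from typing import List, Sequence

# 4-D tensor as nested lists, indexed [d0][d1][d2][d3].
Tensor4 = List[List[List[List[float]]]]


def zeros4(d0: int, d1: int, d2: int, d3: int) -> Tensor4:
    return [[[[0.0] * d3 for _ in range(d2)] for _ in range(d1)] for _ in range(d0)]


def transpose_conv2d_ref(inp: Tensor4, weight: Tensor4, bias: Sequence[float],
                         out_hw: Sequence[int], stride: Sequence[int],
                         out_pad: Sequence[int]) -> Tensor4:
    """Float-only sketch of the TOSA TRANSPOSE_CONV2D pseudocode.

    inp: [N][IH][IW][IC], weight: [OC][KH][KW][IC], bias: [OC],
    out_hw: (OH, OW), stride: (y, x), out_pad: (top, bottom, left, right).
    """
    n_, ih, iw, ic_ = len(inp), len(inp[0]), len(inp[0][0]), len(inp[0][0][0])
    oc_, kh, kw = len(weight), len(weight[0]), len(weight[0][0])
    oh, ow = out_hw
    stride_y, stride_x = stride
    pad_top, _, pad_left, _ = out_pad

    # Loop nest 1: broadcast the bias into every output position.
    out = zeros4(n_, oh, ow, oc_)
    for n, y, x, oc in product(range(n_), range(oh), range(ow), range(oc_)):
        out[n][y][x][oc] = bias[oc]

    # Loop nest 2: scatter-accumulate each input pixel times each kernel tap.
    for n, iy, ix, oc, ic, ky, kx in product(range(n_), range(ih), range(iw),
                                             range(oc_), range(ic_),
                                             range(kh), range(kw)):
        oy = iy * stride_y + pad_top + ky
        ox = ix * stride_x + pad_left + kx
        # Taps that land outside the output are skipped, as in the spec.
        if 0 <= oy < oh and 0 <= ox < ow:
            out[n][oy][ox][oc] += inp[n][iy][ix][ic] * weight[oc][ky][kx][ic]
    return out
```

For example, with a 1x2 input `[1, 2]`, a 1x2 kernel `[10, 100]`, bias `0.5`, stride `(1, 2)` and no padding, the output row is `[10.5, 100.5, 20.5, 200.5]`; with stride `(1, 1)` the middle taps overlap and the row becomes `[10.5, 120.5, 200.5]`.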
Thanks again for your valuable time.
https://github.com/llvm/llvm-project/pull/79824