[all-commits] [llvm/llvm-project] 6b4b63: Lowering for 'tosa.scatter'

rafaelubalmw via All-commits all-commits at lists.llvm.org
Tue May 30 14:34:22 PDT 2023


  Branch: refs/heads/main
  Home:   https://github.com/llvm/llvm-project
  Commit: 6b4b63a832f105039442fc983d0b309abe5261d5
      https://github.com/llvm/llvm-project/commit/6b4b63a832f105039442fc983d0b309abe5261d5
  Author: Rafael Ubal Tena <rubal at mathworks.com>
  Date:   2023-05-30 (Tue, 30 May 2023)

  Changed paths:
    M mlir/lib/Conversion/TosaToSCF/TosaToSCF.cpp
    M mlir/lib/Conversion/TosaToSCF/TosaToSCFPass.cpp
    M mlir/test/Conversion/TosaToSCF/tosa-to-scf.mlir

  Log Message:
  -----------
  Lowering for 'tosa.scatter'

This patch adds support for lowering `tosa.scatter` in the `--tosa-to-scf` pass. Here is an example of this lowering:

```
func.func @tosa(
                %valuesIn : tensor<3x7x5xi32>,
                %indices : tensor<3x6xi32>,
                %input : tensor<3x6x5xi32>) ->
                tensor<3x7x5xi32> {
        %0 = "tosa.scatter"(%valuesIn, %indices, %input) :
                        (tensor<3x7x5xi32>,
                        tensor<3x6xi32>,
                        tensor<3x6x5xi32>) ->
                        (tensor<3x7x5xi32>)
        return %0 : tensor<3x7x5xi32>
}
```

translates to:

```
  func.func @tosa(%arg0: tensor<3x7x5xi32>, %arg1: tensor<3x6xi32>, %arg2: tensor<3x6x5xi32>) -> tensor<3x7x5xi32> {
    %c0 = arith.constant 0 : index
    %c3 = arith.constant 3 : index
    %c1 = arith.constant 1 : index
    %c6 = arith.constant 6 : index
    %c2 = arith.constant 2 : index
    %c5 = arith.constant 5 : index
    %c0_0 = arith.constant 0 : index
    %c1_1 = arith.constant 1 : index
    %0 = scf.for %arg3 = %c0_0 to %c3 step %c1_1 iter_args(%arg4 = %arg0) -> (tensor<3x7x5xi32>) {
      %1 = scf.for %arg5 = %c0_0 to %c6 step %c1_1 iter_args(%arg6 = %arg4) -> (tensor<3x7x5xi32>) {
        %extracted = tensor.extract %arg1[%arg3, %arg5] : tensor<3x6xi32>
        %2 = arith.index_cast %extracted : i32 to index
        %extracted_slice = tensor.extract_slice %arg2[%arg3, %arg5, %c0_0] [%c1_1, %c1_1, %c5] [%c1_1, %c1_1, %c1_1] : tensor<3x6x5xi32> to tensor<?x?x?xi32>
        %inserted_slice = tensor.insert_slice %extracted_slice into %arg6[%arg3, %2, %c0_0] [%c1_1, %c1_1, %c5] [%c1_1, %c1_1, %c1_1] : tensor<?x?x?xi32> into tensor<3x7x5xi32>
        scf.yield %inserted_slice : tensor<3x7x5xi32>
      }
      scf.yield %1 : tensor<3x7x5xi32>
    }
    return %0 : tensor<3x7x5xi32>
  }
```
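To make the loop nest easier to follow, here is a minimal Python reference model of what it computes, using plain nested lists. It is a sketch, not part of the patch: the helper name `tosa_scatter_reference` is hypothetical, and the dimension names `N`, `K`, `W`, `C` are assumed to follow the TOSA specification's naming for `values_in` `[N, K, C]`, `indices` `[N, W]`, and `input` `[N, W, C]`. In the example above, `N = 3`, `K = 7`, `W = 6`, and `C = 5`.

```python
import copy


def tosa_scatter_reference(values_in, indices, input_):
    """Hypothetical reference model of the tosa.scatter loop nest.

    values_in: [N][K][C], indices: [N][W], input_: [N][W][C] -> [N][K][C]
    """
    # The outer scf.for starts its iter_args from %arg0, i.e. from values_in.
    # Deep-copy so the caller's data is not mutated (tensors are values).
    values_out = copy.deepcopy(values_in)
    for n in range(len(indices)):            # outer scf.for (0 to N)
        for w in range(len(indices[n])):     # inner scf.for (0 to W)
            # tensor.extract + arith.index_cast: read the destination row k.
            k = indices[n][w]
            # extract_slice input[n, w, :] and insert_slice it at values_out[n, k, :].
            values_out[n][k] = list(input_[n][w])
    return values_out


# Usage: N = 1, K = 3, W = 2, C = 2.
print(tosa_scatter_reference([[[0, 0], [0, 0], [0, 0]]], [[2, 0]], [[[1, 2], [3, 4]]]))
```

Each inner iteration copies one whole `C`-length row, which is what the `1x1x5` `tensor.extract_slice`/`tensor.insert_slice` pair in the lowering does. Because the model runs sequentially, a repeated index within the same batch simply means the later write wins, as in the loop nest.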

We also tried an alternative lowering that uses `tensor.scatter` as an intermediate step. However, we chose to target the `scf` dialect directly for the following reasons:

- The `tensor.scatter` op does not appear to be used anywhere, and there is currently no lowering pass for it (although we have one that we plan to upstream soon).
- The `tosa.scatter` and `tensor.scatter` ops have different indexing semantics. The `indices` argument of `tosa.scatter` must be non-trivially modified and restructured (e.g. with a `linalg.generic` op) to fit the needs of `tensor.scatter`. While this overhead might be simplified and fused after a subsequent `tensor.scatter` lowering, it adds complex logic and an obscure intermediate state. Unless we are missing a good reason to go through the `tensor` dialect, this additional complexity does not seem justified.
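To illustrate the indexing mismatch described above, the sketch below expands TOSA-style indices, where the batch coordinate `n` is implicit in each index's position, into explicit `(n, k)` coordinate pairs. It assumes, for illustration only, an intermediate `tensor.scatter` form that scatters along the first two destination dimensions and expects one coordinate per scattered dimension in a trailing axis of its `indices` operand; the helper name `tosa_indices_to_coordinates` is hypothetical and is not part of the patch.

```python
def tosa_indices_to_coordinates(indices):
    """Hypothetical helper: expand TOSA indices [N][W] into coordinates [N][W][2].

    In tosa.scatter, indices[n][w] holds only the row k; the batch n is implied
    by the position. An op addressing the destination by full coordinates along
    the scattered dims needs the pair (n, k) spelled out explicitly.
    """
    return [[[n, k] for k in row] for n, row in enumerate(indices)]


# Usage: two batches, two indices each.
print(tosa_indices_to_coordinates([[2, 0], [1, 1]]))
```

In IR, producing this extra coordinate tensor is the kind of restructuring that would need an additional op such as `linalg.generic`, whereas the direct `scf` lowering uses `%arg3` (the batch loop variable) as the batch offset without materializing anything.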

Reviewed By: eric-k256

Differential Revision: https://reviews.llvm.org/D151117



