[all-commits] [llvm/llvm-project] ef6e7a: [mlir] [tosa] Bug fixes in shape inference pass (#...

Rafael Ubal via All-commits all-commits at lists.llvm.org
Fri Aug 16 08:11:10 PDT 2024


  Branch: refs/heads/main
  Home:   https://github.com/llvm/llvm-project
  Commit: ef6e7affbb7b0eb4976c1019c788bcadfc34ecd6
      https://github.com/llvm/llvm-project/commit/ef6e7affbb7b0eb4976c1019c788bcadfc34ecd6
  Author: Rafael Ubal <rubal at mathworks.com>
  Date:   2024-08-16 (Fri, 16 Aug 2024)

  Changed paths:
    M mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp
    M mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir

  Log Message:
  -----------
  [mlir] [tosa] Bug fixes in shape inference pass (#104146)

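This change addresses two bugs in the TOSA shape inference pass
(`--tosa-infer-shapes`): the `tensor.cast` that restores the original
`tensor<?x2x8xf32>` type for the `tensor.expand_shape` consumers was inserted
between the two `tensor.expand_shape` ops instead of directly after the refined
`tosa.cast`, and the iterator traversing the use list of the `tosa.cast` op was
invalidated, so the first `tensor.expand_shape` kept its original operand. The
included unit test contains a detailed description of the issues.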

- Input IR

```
func.func @main(%arg0: tensor<1x2x8xf32>) {
  %0 = tosa.cast %arg0 : (tensor<1x2x8xf32>) -> tensor<?x2x8xf32>

  %c0 = arith.constant 0 : index
  %dim = tensor.dim %0, %c0 : tensor<?x2x8xf32>

  %expanded_0 = tensor.expand_shape %0 [[0], [1, 2], [3]] output_shape [%dim, 1, 2, 8] : tensor<?x2x8xf32> into tensor<?x1x2x8xf32>
  %expanded_1 = tensor.expand_shape %0 [[0], [1, 2], [3]] output_shape [%dim, 1, 2, 8] : tensor<?x2x8xf32> into tensor<?x1x2x8xf32>
  return
}
```

- Output IR

```
module {
  func.func @main(%arg0: tensor<1x2x8xf32>) {
    %0 = tosa.cast %arg0 : (tensor<1x2x8xf32>) -> tensor<1x2x8xf32>

    // This cast was previously inserted between both 'tensor.expand_shape' ops.
    %cast = tensor.cast %0 : tensor<1x2x8xf32> to tensor<?x2x8xf32>

    %c0 = arith.constant 0 : index
    %dim = tensor.dim %0, %c0 : tensor<1x2x8xf32>

    // The operand of the first 'tensor.expand_shape' op was not previously updated
    // from '%0' to '%cast' due to an invalidation of the iterator traversing the
    // use list of the 'tosa.cast' op.
    %expanded_0 = tensor.expand_shape %cast [[0], [1, 2], [3]] output_shape [%dim, 1, 2, 8] : tensor<?x2x8xf32> into tensor<?x1x2x8xf32>
    %expanded_1 = tensor.expand_shape %cast [[0], [1, 2], [3]] output_shape [%dim, 1, 2, 8] : tensor<?x2x8xf32> into tensor<?x1x2x8xf32>
    return
  }
}
```
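Rewriting a use list while walking it is a general hazard, not something specific to TOSA. The C++ sketch below illustrates both issues on a hypothetical, heavily simplified model: the `Use`/`Value` structs and the `setUse`, `replaceUsesIf` and `insertAfter` helpers are illustrative names only, not MLIR API and not the code of this commit. It shows the early-increment pattern (advance the iterator before the current use is unlinked) and placing the new cast directly after its producer, as in the output IR above. In LLVM/MLIR code, a similar effect is commonly obtained by iterating with `llvm::make_early_inc_range`.

```cpp
#include <cassert>
#include <iterator>
#include <list>
#include <string>

// Hypothetical, simplified model of an SSA value and its use list.
// These types are illustrative only; they are NOT the MLIR API.
struct Value;

struct Use {
  std::string owner;       // name of the consuming op, e.g. "expanded_0"
  Value *value = nullptr;  // value currently referenced by this use
};

struct Value {
  std::string name;
  std::list<Use *> uses;   // stands in for MLIR's intrusive use list
};

// Re-point `use` at `newValue`: unlink it from the old value's use list and
// append it to the new one. Unlinking erases the list node for `use`, which
// invalidates any iterator still pointing at that node.
void setUse(Use &use, Value &newValue) {
  assert(use.value != nullptr && "use must reference a value");
  use.value->uses.remove(&use);
  use.value = &newValue;
  newValue.uses.push_back(&use);
}

// Replace every use of `from` accepted by `shouldReplace` with `to`.
// A naive `for (it = begin; it != end; ++it) setUse(**it, to);` would
// increment an iterator whose node was just erased. Advancing the iterator
// before modifying the current use avoids that (early increment).
template <typename Pred>
void replaceUsesIf(Value &from, Value &to, Pred shouldReplace) {
  if (&from == &to)
    return;  // nothing to do; also avoids re-visiting re-appended uses
  for (auto it = from.uses.begin(); it != from.uses.end();) {
    Use *use = *it;
    ++it;  // step past the node before setUse() erases it
    if (shouldReplace(*use))
      setUse(*use, to);
  }
}

// Insert a new op directly after its producer in a block modeled as a list
// of op names, so it precedes (and dominates) every later consumer.
std::list<std::string>::iterator
insertAfter(std::list<std::string> &block,
            std::list<std::string>::iterator producer, const std::string &op) {
  return block.insert(std::next(producer), op);
}
```

With uses `dim`, `expanded_0` and `expanded_1` on a value `%0`, plus the cast's own use of `%0`, a predicate that selects only the `expanded_*` owners moves both expand uses to `%cast` while the other two stay on `%0`, mirroring the output IR.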
