[all-commits] [llvm/llvm-project] 0a94d3: [mlir][tosa] Fix tosa-infer-shapes crash (#87234)

Spenser Bauman via All-commits all-commits at lists.llvm.org
Tue Apr 2 16:45:49 PDT 2024


  Branch: refs/heads/main
  Home:   https://github.com/llvm/llvm-project
  Commit: 0a94d35bfb81cb0bef60ebe60513d191661da0bd
      https://github.com/llvm/llvm-project/commit/0a94d35bfb81cb0bef60ebe60513d191661da0bd
  Author: Spenser Bauman <sbauman at mathworks.com>
  Date:   2024-04-02 (Tue, 02 Apr 2024)

  Changed paths:
    M mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp
    M mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir

  Log Message:
  -----------
  [mlir][tosa] Fix tosa-infer-shapes crash (#87234)

The tosa-infer-shapes pass inserts tensor.cast operations to mediate
refined result types with consumers whose types cannot be refined. This
process interferes with how types are refined in tosa.while_loop body
regions, where types are propagated speculatively (to determine the
types of the tosa.yield terminator) and then reverted.

The newly inserted tensor.cast operations cause a crash because they have
no original types recorded for the reversion process to restore.

This change modifies the shape propagation behavior so that the
introduction of tensor.cast operations no longer interferes with this
type reversion process. The new behavior is to introduce tensor.cast
operations only when the newly computed types are committed to the IR.
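
The deferred-commit idea can be modeled outside MLIR. The following is a
minimal sketch under stated assumptions, not the code from this patch:
`ToyValue` and `PendingTypes` are hypothetical names, and types are plain
strings rather than MLIR types. Refinements go into a side table, so a
speculative pass over a loop body can be dropped without touching the IR,
and cast insertion would happen only in `commit`.

```cpp
#include <cassert>
#include <map>
#include <string>
#include <utility>
#include <vector>

// Hypothetical stand-in for an SSA value: an id plus a type spelled as text.
struct ToyValue {
  int id;
  std::string type;
};

// Hypothetical side table of refined types. Nothing in the "IR" changes
// until commit() is called, so a speculative pass can be dropped cheaply.
class PendingTypes {
public:
  // Record a refined type without touching the value itself.
  void refine(const ToyValue &v, std::string refined) {
    pending_[v.id] = std::move(refined);
  }

  // Best-known type: the pending refinement if present, else the original.
  std::string lookup(const ToyValue &v) const {
    auto it = pending_.find(v.id);
    return it == pending_.end() ? v.type : it->second;
  }

  // Forget every speculative refinement (e.g. after one loop-body pass).
  void rollback() { pending_.clear(); }

  // Write refinements into the values. Returns the ids whose type changed;
  // a real pass would insert casts for their non-refinable users here.
  std::vector<int> commit(std::vector<ToyValue> &values) {
    std::vector<int> changed;
    for (ToyValue &v : values) {
      auto it = pending_.find(v.id);
      if (it == pending_.end() || it->second == v.type)
        continue;
      v.type = it->second;
      changed.push_back(v.id);
    }
    pending_.clear();
    return changed;
  }

private:
  std::map<int, std::string> pending_;
};
```

In this model, a rollback never has to account for operations created during
speculation, because none are created before `commit` runs.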

The following example triggers the crash:

```mlir
func.func @while_dont_crash(%arg0 : tensor<i32>) -> (tensor<*xi32>) {
  %0 = tosa.add %arg0, %arg0 : (tensor<i32>, tensor<i32>) -> tensor<*xi32>

  %1 = tosa.while_loop (%arg1 = %0) : (tensor<*xi32>) -> tensor<*xi32> {
    %2 = "tosa.const"() <{value = dense<3> : tensor<i32>}> : () -> tensor<i32>
    %3 = tosa.greater_equal %2, %arg1 : (tensor<i32>, tensor<*xi32>) -> tensor<*xi1>
    tosa.yield %3 : tensor<*xi1>
  } do {
  ^bb0(%arg1: tensor<*xi32>):
    // Inferrable operation whose type will refine to tensor<i32>
    %3 = tosa.add %arg1, %arg1 : (tensor<*xi32>, tensor<*xi32>) -> tensor<*xi32>

    // Non-inferrable use site, will require the cast:
    //     tensor.cast %3 : tensor<i32> to tensor<*xi32>
    // 
    // The new cast operation will result in accessing undefined memory through
    // originalTypeMap in the C++ code.
    "use"(%3) : (tensor<*xi32>) -> ()
    tosa.yield %3 : tensor<*xi32>
  }

  return %1 : tensor<*xi32>
}
```

The `tensor.cast` operation inserted in the loop body causes a failure
in the code which resets the types after propagation through the loop
body:

```c++
// The types inferred in the block assume the operand types specified for
// this iteration. We need to restore the original types to ensure that
// future iterations only use the already specified types, not possible
// types from previous iterations.
for (auto &block : bodyRegion) {
  for (auto arg : block.getArguments())
    arg.setType(originalTypeMap[arg]);
  for (auto &op : block)
    for (auto result : op.getResults())
      result.setType(originalTypeMap[result]);  // problematic access
}
```
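
The hazard in the marked line can be illustrated with a standard container.
This is a hedged sketch, assuming the snapshot map behaves like `std::map`,
whose `operator[]` inserts a default-constructed entry for a missing key.
The names `TypeMap`, `restoreUnchecked`, and `restoreChecked` are
hypothetical, and an empty string stands in for a null type. The checked
lookup only illustrates the hazard; the fix in this commit is to defer cast
insertion, not to guard the lookup.

```cpp
#include <cassert>
#include <map>
#include <string>

// Hypothetical model: value ids map to type names; an empty string plays
// the role of a null type.
using TypeMap = std::map<int, std::string>;

// Unchecked restore, mirroring originalTypeMap[result]: a missing key is
// silently inserted with an empty ("null") type, which is then returned.
std::string restoreUnchecked(TypeMap &original, int id) {
  return original[id];
}

// Checked restore: a value created after the snapshot keeps its current
// type, and the snapshot itself is left unmodified.
std::string restoreChecked(const TypeMap &original, int id,
                           const std::string &current) {
  auto it = original.find(id);
  return it == original.end() ? current : it->second;
}
```

A value that was created after the snapshot was taken, such as a freshly
inserted cast result, is exactly the case where the unchecked form returns
the empty placeholder.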

---------

Co-authored-by: Spenser Bauman <sabauma at fastmail>


