[all-commits] [llvm/llvm-project] 843f1f: [mlir][scf] Add scf.for + tensor.cast canonicaliza...

Nicolas Vasilache via All-commits all-commits at lists.llvm.org
Fri Apr 16 09:55:18 PDT 2021


  Branch: refs/heads/main
  Home:   https://github.com/llvm/llvm-project
  Commit: 843f1fc82598216a4be672ba51820b037dae106b
      https://github.com/llvm/llvm-project/commit/843f1fc82598216a4be672ba51820b037dae106b
  Author: Nicolas Vasilache <nicolas.vasilache at gmail.com>
  Date:   2021-04-16 (Fri, 16 Apr 2021)

  Changed paths:
    M mlir/lib/Dialect/SCF/SCF.cpp
    M mlir/test/Dialect/SCF/canonicalize.mlir

  Log Message:
  -----------
  [mlir][scf] Add scf.for + tensor.cast canonicalization pattern

Fold scf.for iter_arg/result pairs that pass through an incoming/outgoing
tensor.cast op pair, so as to pull the tensor.cast ops inside the scf.for:

```
  %0 = tensor.cast %t0 : tensor<32x1024xf32> to tensor<?x?xf32>
  %1 = scf.for %i = %c0 to %c1024 step %c32 iter_args(%iter_t0 = %0)
     -> (tensor<?x?xf32>) {
    %2 = call @do(%iter_t0) : (tensor<?x?xf32>) -> tensor<?x?xf32>
    scf.yield %2 : tensor<?x?xf32>
  }
  %3 = tensor.cast %1 : tensor<?x?xf32> to tensor<32x1024xf32>
  use_of(%3)
```

folds into:

```
  %0 = scf.for %arg2 = %c0 to %c1024 step %c32 iter_args(%arg3 = %t0)
      -> (tensor<32x1024xf32>) {
    %2 = tensor.cast %arg3 : tensor<32x1024xf32> to tensor<?x?xf32>
    %3 = call @do(%2) : (tensor<?x?xf32>) -> tensor<?x?xf32>
    %4 = tensor.cast %3 : tensor<?x?xf32> to tensor<32x1024xf32>
    scf.yield %4 : tensor<32x1024xf32>
  }
  use_of(%0)
```
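
As a rough illustration of the type bookkeeping in the example above, here is
a toy Python model. It is not the C++ pattern added in SCF.cpp: the names
`CastedLoop`, `FoldedLoop` and `fold_casts_into_loop` are hypothetical, types
are plain strings, and it assumes the fold only applies when the outgoing cast
returns to the same type the incoming cast started from, as in the example.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass(frozen=True)
class CastedLoop:
    """Toy stand-in for the 'before' IR: tensor.cast -> scf.for -> tensor.cast."""
    source_type: str  # type of the value fed to the incoming tensor.cast
    iter_type: str    # type carried by iter_args and yielded by the loop
    result_type: str  # type produced by the outgoing tensor.cast


@dataclass(frozen=True)
class FoldedLoop:
    """Toy stand-in for the 'after' IR: scf.for with both casts in its body."""
    iter_type: str        # type now carried by iter_args and the loop result
    body_input_type: str  # type the in-body cast hands to the original body


def fold_casts_into_loop(loop: CastedLoop) -> Optional[FoldedLoop]:
    # Assumption: the casts can only be pulled inside when the outer pair
    # round-trips to the same type; otherwise leave the IR untouched.
    if loop.source_type != loop.result_type:
        return None
    # The loop now carries the original (static) type; the body still sees
    # the old iter type through a cast placed at the top of the body.
    return FoldedLoop(iter_type=loop.source_type,
                      body_input_type=loop.iter_type)


before = CastedLoop("tensor<32x1024xf32>", "tensor<?x?xf32>",
                    "tensor<32x1024xf32>")
after = fold_casts_into_loop(before)
```

In this model `after` carries `tensor<32x1024xf32>` through the loop while the
body still receives `tensor<?x?xf32>`, which mirrors the two tensor.cast ops
that end up inside the scf.for body.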

Differential Revision: https://reviews.llvm.org/D100661



