[all-commits] [llvm/llvm-project] 2ae37b: Allow empty dimension arrays in `linalg::inferCont...

bjacob via All-commits all-commits at lists.llvm.org
Thu Oct 19 10:13:44 PDT 2023


  Branch: refs/heads/main
  Home:   https://github.com/llvm/llvm-project
  Commit: 2ae37be4b433632e46209aa04fcc857675783f81
      https://github.com/llvm/llvm-project/commit/2ae37be4b433632e46209aa04fcc857675783f81
  Author: bjacob <jacob.benoit.1 at gmail.com>
  Date:   2023-10-19 (Thu, 19 Oct 2023)

  Changed paths:
    M mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
    M mlir/test/Dialect/Linalg/match-ops-interpreter.mlir

  Log Message:
  -----------
  Allow empty dimension arrays in `linalg::inferContractionDims` (#69496)

This function was returning failure when any of the intersection sets
was empty, but this is actually legitimate in "matrix times vector"
cases, where some of the operands have lower dimensionality, implying
unit-dimension semantics for the "missing" dimensions.

Example:

```mlir
func.func @transpose_extend_batch_matmul(
    %vec: tensor<32x128xi16>,
    %mat: tensor<11008x32x128xi4>) -> tensor<11008x32xi32> {
  %c0_i32 = arith.constant 0 : i32
  %cst_0 = arith.constant 0.000000e+00 : f32
  %0 = tensor.empty() : tensor<11008x32xi32>
  %1 = linalg.fill ins(%c0_i32 : i32) outs(%0 : tensor<11008x32xi32>) -> tensor<11008x32xi32>
  %2 = tensor.empty() : tensor<11008xf32>
  %3 = linalg.fill ins(%cst_0 : f32) outs(%2 : tensor<11008xf32>) -> tensor<11008xf32>
  %batch_matmul_result = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d1, d2)>, 
                                                          affine_map<(d0, d1, d2) -> (d0, d1, d2)>, 
                                                          affine_map<(d0, d1, d2) -> (d0, d1)>], 
                                         iterator_types = ["parallel", "parallel", "reduction"]} 
                                         ins(%vec, %mat : tensor<32x128xi16>, tensor<11008x32x128xi4>) 
                                         outs(%1 : tensor<11008x32xi32>) {
  ^bb0(%in: i16, %in_3: i4, %out: i32):
      %19 = arith.extsi %in : i16 to i32
      %20 = arith.extui %in_3 : i4 to i32
      %21 = arith.muli %19, %20 : i32
      %22 = arith.addi %21, %out : i32
      linalg.yield %22 : i32
  } -> tensor<11008x32xi32>
  return %batch_matmul_result : tensor<11008x32xi32>
}
```

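Here, we were returning failure because `ac` (the parallel dimensions
indexed by the first input and the output but not by the second input,
i.e. the `m` dimensions) is empty. With this PR, we return this useful
information: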

```
batch: [ 1 ]
m: [ ]
n: [ 0 ]
k: [ 2 ]
```
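
The result above can be reproduced by hand with plain set operations. The following is a minimal Python sketch of that classification, not the C++ code in `LinalgInterfaces.cpp`: the function name `infer_contraction_dims` and the encoding of each indexing map as a tuple of loop-dimension indices are assumptions made for illustration, and the sketch only handles maps that are plain projections of the loop dimensions.

```python
def infer_contraction_dims(map_a, map_b, map_c, iterator_types):
    """Classify loop dimensions of a contraction-like op (illustrative only).

    map_a, map_b, map_c: tuples of loop-dimension indices used by the two
    inputs and the output, e.g. (d1, d2) is written as (1, 2).
    iterator_types: list of "parallel" / "reduction", one per loop dimension.
    """
    def dims(indexing_map, kind):
        # Loop dimensions of the given iterator kind that this operand indexes.
        return {d for d in indexing_map if iterator_types[d] == kind}

    a = dims(map_a, "parallel")
    b = dims(map_b, "parallel")
    c = dims(map_c, "parallel")

    batch = a & b & c        # parallel dims indexed by all three operands
    m = (a & c) - b          # "ac": in the first input and output only
    n = (b & c) - a          # "bc": in the second input and output only
    k = dims(map_a, "reduction") & dims(map_b, "reduction")

    # Empty lists are valid results: they stand for implicit unit dimensions.
    return {"batch": sorted(batch), "m": sorted(m),
            "n": sorted(n), "k": sorted(k)}


# Indexing maps from the linalg.generic in the example above.
result = infer_contraction_dims(
    map_a=(1, 2),            # (d0, d1, d2) -> (d1, d2)
    map_b=(0, 1, 2),         # (d0, d1, d2) -> (d0, d1, d2)
    map_c=(0, 1),            # (d0, d1, d2) -> (d0, d1)
    iterator_types=["parallel", "parallel", "reduction"],
)
print(result)
```

For the example's maps this yields an empty `m` list alongside non-empty `batch`, `n`, and `k`, which is the case that was previously rejected.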