[Mlir-commits] [mlir] [mlir][linalg] Relax tensor.extract vectorization (PR #99299)

Andrzej Warzyński llvmlistbot at llvm.org
Thu Aug 1 09:09:56 PDT 2024


banach-space wrote:

> > > Hey! I was looking at the changes and thinking that we may end up introducing too much complexity if we add support for all these unit dimension special cases.
> > 
> > 
> > Thanks for taking a look! Note that this is actually simplifying the current logic 😅 (and reducing the number of special cases).
> 
> The way I see it is that it's replacing an early exit with a special case.
> 

Well, I am replacing:
```cpp
if (some_complex_condition)
  return VectorMemoryAccessKind::Gather;
  
if (some_other_complex_condition)
  return VectorMemoryAccessKind::Gather;
```

with:
```cpp
  // One less complex condition
  if (!isOutput1DVector)
    return VectorMemoryAccessKind::Gather;
```

😅 
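For illustration, here is a standalone sketch of what a check like `isOutput1DVector` could compute. One plausible reading, assumed here rather than copied from the PR, is that the output is "effectively 1-D" when exactly one dimension of the target vector shape is larger than 1 (e.g. `vector<1x1x4xi32>` or `vector<1x4x1xi32>`). The helper names `isEffectively1D` and `mustBeGather` are hypothetical.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical helper modelled on `isOutput1DVector`: a vector shape is
// "effectively 1-D" when exactly one of its dimensions is larger than 1.
bool isEffectively1D(const std::vector<int64_t> &targetShape) {
  assert(!targetShape.empty() && "0-D vectors are assumed to be handled earlier");
  auto nonUnitDims = std::count_if(targetShape.begin(), targetShape.end(),
                                   [](int64_t dimSize) { return dimSize > 1; });
  return nonUnitDims == 1;
}

// Hypothetical early exit: any output that is not effectively 1-D is treated
// as a gather. For 1-D outputs the real vectorizer goes on to inspect the
// extract indices, which this sketch does not model.
bool mustBeGather(const std::vector<int64_t> &targetShape) {
  return !isEffectively1D(targetShape);
}
```

With this rule, `vector<1x1x4xi32>` (the output in the example below) passes the single shape check, while a genuinely 2-D output such as `vector<4x4xi32>` falls straight through to the gather path.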

> Would you have an example of what the gather would look like after vectorization and after removing the unit dim? That would be helpful to make a call.

Sure!

**BEFORE THIS CHANGE**
```mlir
  func.func @vectorize_scalar_broadcast_column_tensor(%arg0: tensor<1x1x4xi32>) -> tensor<1x1x4xi32> {
    %c4 = arith.constant 4 : index
    %c0 = arith.constant 0 : index
    %cst = arith.constant dense<[[0], [0], [1], [1], [2], [2], [3], [3], [4], [4], [5], [5], [6], [6], [7], [7], [8], [8], [9], [9], [10], [10], [11], [11], [12], [12], [13], [13], [14], [14]]> : tensor<30x1xi32>
    %c1 = arith.constant 1 : index
    %c1_0 = arith.constant 1 : index
    %c4_1 = arith.constant 4 : index
    %c0_2 = arith.constant 0 : index
    %c0_i32 = arith.constant 0 : i32
    %0 = vector.transfer_read %arg0[%c0_2, %c0_2, %c0_2], %c0_i32 : tensor<1x1x4xi32>, vector<1x1x4xi32>
    %1 = vector.step : vector<1xindex>
    %2 = vector.broadcast %1 : vector<1xindex> to vector<4x1x1xindex>
    %3 = vector.transpose %2, [2, 1, 0] : vector<4x1x1xindex> to vector<1x1x4xindex>
    %4 = vector.step : vector<1xindex>
    %5 = vector.broadcast %4 : vector<1xindex> to vector<4x1x1xindex>
    %6 = vector.transpose %5, [2, 1, 0] : vector<4x1x1xindex> to vector<1x1x4xindex>
    %cst_3 = arith.constant dense<true> : vector<1x1x4xi1>
    %cst_4 = arith.constant dense<0> : vector<1x1x4xi32>
    %c0_5 = arith.constant 0 : index
    %c1_6 = arith.constant 1 : index
    %dim = tensor.dim %cst, %c1_6 : tensor<30x1xi32>
    %7 = vector.broadcast %dim : index to vector<1x1x4xindex>
    %8 = arith.muli %6, %7 : vector<1x1x4xindex>
    %cst_7 = arith.constant dense<0> : vector<1x1x4xindex>
    %9 = arith.addi %cst_7, %8 : vector<1x1x4xindex>
    %10 = vector.gather %cst[%c0_5, %c0_5] [%9], %cst_3, %cst_4 : tensor<30x1xi32>, vector<1x1x4xindex>, vector<1x1x4xi1>, vector<1x1x4xi32> into vector<1x1x4xi32>
    %c0_8 = arith.constant 0 : index
    %11 = vector.transfer_write %10, %arg0[%c0_8, %c0_8, %c0_8] : vector<1x1x4xi32>, tensor<1x1x4xi32>
    return %11 : tensor<1x1x4xi32>
  }
```
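To see why this `vector.gather` is really a scalar broadcast, the index arithmetic can be modelled in plain C++: `vector.step : vector<1xindex>` yields `[0]`, broadcasting and transposing it to `vector<1x1x4xindex>` gives four zeros, and multiplying by `tensor.dim %cst, 1` (which is 1) and adding 0 keeps them zero, so every lane reads the same element. This is a sketch under two assumptions: the tensor is modelled as a row-major flattened array, and the helper names (`makeCst`, `gatherIndices`, `gatherLoad`) are hypothetical.

```cpp
#include <array>
#include <cassert>
#include <cstddef>
#include <cstdint>

constexpr std::size_t kRows = 30;
constexpr std::size_t kCols = 1;   // value of `tensor.dim %cst, %c1_6`
constexpr std::size_t kLanes = 4;  // trailing dim of vector<1x1x4xindex>

// Row-major model of the tensor<30x1xi32> constant %cst.
using Tensor30x1 = std::array<int32_t, kRows * kCols>;

// Hypothetical model of %cst: [[0], [0], [1], [1], ..., [14], [14]].
Tensor30x1 makeCst() {
  Tensor30x1 cst{};
  for (std::size_t r = 0; r < kRows; ++r)
    cst[r * kCols] = static_cast<int32_t>(r / 2);
  return cst;
}

// Models %4 ... %9: every lane ends up with the single step value 0.
std::array<std::size_t, kLanes> gatherIndices() {
  const std::array<std::size_t, 1> step = {0};  // vector.step : vector<1xindex>
  std::array<std::size_t, kLanes> idx{};
  for (std::size_t lane = 0; lane < kLanes; ++lane) {
    std::size_t row = step[0];    // broadcast + transpose: each lane sees step[0]
    idx[lane] = 0 + row * kCols;  // arith.muli by %dim, then arith.addi with 0
  }
  return idx;
}

// Models vector.gather with an all-true mask: one load per lane.
std::array<int32_t, kLanes> gatherLoad(const Tensor30x1 &cst,
                                       const std::array<std::size_t, kLanes> &idx) {
  std::array<int32_t, kLanes> out{};
  for (std::size_t lane = 0; lane < kLanes; ++lane) {
    assert(idx[lane] < cst.size() && "out-of-bounds gather index");
    out[lane] = cst[idx[lane]];
  }
  return out;
}
```

Because all four indices coincide, the four "gathered" loads all hit `%cst[0][0]`, which is exactly the situation the vectorizer tries to recognise as a broadcast.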

**AFTER THIS CHANGE**
```mlir
  func.func @vectorize_scalar_broadcast_column_tensor(%arg0: tensor<1x1x4xi32>) -> tensor<1x1x4xi32> {
    %c4 = arith.constant 4 : index
    %c0 = arith.constant 0 : index
    %cst = arith.constant dense<[[0], [0], [1], [1], [2], [2], [3], [3], [4], [4], [5], [5], [6], [6], [7], [7], [8], [8], [9], [9], [10], [10], [11], [11], [12], [12], [13], [13], [14], [14]]> : tensor<30x1xi32>
    %c1 = arith.constant 1 : index
    %c1_0 = arith.constant 1 : index
    %c4_1 = arith.constant 4 : index
    %c0_2 = arith.constant 0 : index
    %c0_i32 = arith.constant 0 : i32
    %0 = vector.transfer_read %arg0[%c0_2, %c0_2, %c0_2], %c0_i32 : tensor<1x1x4xi32>, vector<1x1x4xi32>
    %1 = vector.step : vector<1xindex>
    %2 = vector.broadcast %1 : vector<1xindex> to vector<4x1x1xindex>
    %3 = vector.transpose %2, [2, 1, 0] : vector<4x1x1xindex> to vector<1x1x4xindex>
    %4 = vector.step : vector<1xindex>
    %5 = vector.broadcast %4 : vector<1xindex> to vector<4x1x1xindex>
    %6 = vector.transpose %5, [2, 1, 0] : vector<4x1x1xindex> to vector<1x1x4xindex>
    %cst_3 = arith.constant dense<true> : vector<1x1x4xi1>
    %cst_4 = arith.constant dense<0> : vector<1x1x4xi32>
    %c0_5 = arith.constant 0 : index
    %c0_i32_6 = arith.constant 0 : i32
    %7 = vector.shape_cast %6 : vector<1x1x4xindex> to vector<4xindex>
    %8 = vector.extractelement %7[%c0_i32_6 : i32] : vector<4xindex>
    %c0_i32_7 = arith.constant 0 : i32
    %9 = vector.transfer_read %cst[%8, %c0], %c0_i32_7 {in_bounds = [true, true, true], permutation_map = #map} : tensor<30x1xi32>, vector<1x1x4xi32>
    %c0_8 = arith.constant 0 : index
    %10 = vector.transfer_write %9, %arg0[%c0_8, %c0_8, %c0_8] : vector<1x1x4xi32>, tensor<1x1x4xi32>
    return %10 : tensor<1x1x4xi32>
  }
```
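The rewritten IR relies on the fact that a gather whose lanes all use the same index is just one scalar load plus a broadcast: `%8` extracts lane 0 of the index vector, and the `vector.transfer_read` (whose `#map` is not shown in the snippet, presumably a broadcasting permutation map) reads that single element into every lane. The sketch below checks that equivalence on plain arrays. The function names are hypothetical, and masking and padding are ignored.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Models vector.gather with an all-true mask: one load per lane.
std::vector<int32_t> gatherLoad(const std::vector<int32_t> &src,
                                const std::vector<std::size_t> &idx) {
  std::vector<int32_t> out;
  out.reserve(idx.size());
  for (std::size_t i : idx)
    out.push_back(src.at(i));
  return out;
}

// True when every lane uses the same index, i.e. the gather is a broadcast.
bool isUniform(const std::vector<std::size_t> &idx) {
  return std::all_of(idx.begin(), idx.end(),
                     [&](std::size_t i) { return i == idx.front(); });
}

// Models vector.extractelement (lane 0) followed by a broadcasting
// vector.transfer_read: a single scalar load splatted across all lanes.
std::vector<int32_t> broadcastLoad(const std::vector<int32_t> &src,
                                   const std::vector<std::size_t> &idx) {
  assert(!idx.empty() && "need at least one lane");
  return std::vector<int32_t>(idx.size(), src.at(idx[0]));
}
```

The two functions agree only when `isUniform` holds, which is why the rewrite is valid only after the index computation has been proven uniform.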

So, we'd need to match:
```mlir
    %4 = vector.step : vector<1xindex>
    %5 = vector.broadcast %4 : vector<1xindex> to vector<4x1x1xindex>
    %6 = vector.transpose %5, [2, 1, 0] : vector<4x1x1xindex> to vector<1x1x4xindex>
    %dim = tensor.dim %cst, %c1_6 : tensor<30x1xi32>
    %7 = vector.broadcast %dim : index to vector<1x1x4xindex>
    %8 = arith.muli %6, %7 : vector<1x1x4xindex>
    %cst_7 = arith.constant dense<0> : vector<1x1x4xindex>
    %9 = arith.addi %cst_7, %8 : vector<1x1x4xindex>
    %10 = vector.gather %cst[%c0_5, %c0_5] [%9], %cst_3, %cst_4 : tensor<30x1xi32>, vector<1x1x4xindex>, vector<1x1x4xi1>, vector<1x1x4xi32> into vector<1x1x4xi32>
```

Not that bad, but that's "roughly" what the vectorizer matches today to decide that the underlying `tensor.extract` is a broadcast of a scalar. IIUC, you are suggesting removing that logic from the vectorizer and creating a Vector dialect pattern instead? Would you go as far as simplifying the vectorizer to always generate `vector.gather` and then let Vector patterns "lower" that to something more efficient?
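To make that trade-off concrete, here is a toy sketch of the decision such a Vector dialect pattern would have to make, written over index vectors whose values are already known. The enum and function names are hypothetical and only loosely mirror `VectorMemoryAccessKind`. A real pattern would have to prove uniformity or contiguity symbolically from the SSA chain matched above (`vector.step`, `vector.broadcast`, `vector.transpose`, `arith.muli`, `arith.addi`), which is where the complexity would move.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical mirror of the three access kinds discussed in this thread.
enum class AccessKind { ScalarBroadcast, Contiguous, Gather };

// Classify a gather from its linearized index vector:
//  - all lanes equal       -> one scalar load + broadcast
//  - idx[i] == idx[0] + i  -> one contiguous vector load
//  - anything else         -> keep the gather
AccessKind classifyGatherIndices(const std::vector<int64_t> &idx) {
  assert(!idx.empty() && "expected a non-empty index vector");
  bool uniform = true;
  bool contiguous = true;
  for (std::size_t i = 1; i < idx.size(); ++i) {
    uniform = uniform && (idx[i] == idx[0]);
    contiguous = contiguous && (idx[i] == idx[0] + static_cast<int64_t>(i));
  }
  if (uniform)
    return AccessKind::ScalarBroadcast;
  if (contiguous)
    return AccessKind::Contiguous;
  return AccessKind::Gather;
}
```

In the example above, the index vector `%9` would be all zeros, so this sketch would classify it as `ScalarBroadcast`.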

https://github.com/llvm/llvm-project/pull/99299
