[Mlir-commits] [mlir] [mlir][linalg] Relax tensor.extract vectorization (PR #99299)
Andrzej Warzyński
llvmlistbot at llvm.org
Thu Aug 1 09:09:56 PDT 2024
banach-space wrote:
> > > Hey! I was looking at the changes and thinking that we may end up introducing too much complexity if we add support for all these unit dimension special cases.
> >
> >
> > Thanks for taking a look! Note that this is actually simplifying the current logic 😅 (and reducing the number of special cases).
>
> The way I see it is that it's replacing an early exit with a special case.
>
Well, I am replacing:
```cpp
if (some_complex_condition)
return VectorMemoryAccessKind::Gather;
if (some_other_complex_condition)
return VectorMemoryAccessKind::Gather;
```
with:
```cpp
// One less complex condition
if (!isOutput1DVector)
return VectorMemoryAccessKind::Gather;
```
😅
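For context, a minimal sketch of what that single check could boil down to (assuming the output vector shape is available as an `ArrayRef<int64_t>`; the names here are illustrative, not the exact code from the patch):
```cpp
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"

// Illustrative only: "1-D output vector" here means at most one non-unit
// dimension in the result shape, e.g. vector<1x1x4xi32>.
static bool isOutput1DVector(llvm::ArrayRef<int64_t> resultShape) {
  return llvm::count_if(resultShape,
                        [](int64_t dimSize) { return dimSize != 1; }) <= 1;
}
```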
> Would you have an example of how the gather would look like after vectorization and after removing the unit dim? That would be helpful to make a call.
Sure!

**BEFORE THIS CHANGE**
```mlir
func.func @vectorize_scalar_broadcast_column_tensor(%arg0: tensor<1x1x4xi32>) -> tensor<1x1x4xi32> {
%c4 = arith.constant 4 : index
%c0 = arith.constant 0 : index
%cst = arith.constant dense<[[0], [0], [1], [1], [2], [2], [3], [3], [4], [4], [5], [5], [6], [6], [7], [7], [8], [8], [9], [9], [10], [10], [11], [11], [12], [12], [13], [13], [14], [14]]> : tensor<30x1xi32>
%c1 = arith.constant 1 : index
%c1_0 = arith.constant 1 : index
%c4_1 = arith.constant 4 : index
%c0_2 = arith.constant 0 : index
%c0_i32 = arith.constant 0 : i32
%0 = vector.transfer_read %arg0[%c0_2, %c0_2, %c0_2], %c0_i32 : tensor<1x1x4xi32>, vector<1x1x4xi32>
%1 = vector.step : vector<1xindex>
%2 = vector.broadcast %1 : vector<1xindex> to vector<4x1x1xindex>
%3 = vector.transpose %2, [2, 1, 0] : vector<4x1x1xindex> to vector<1x1x4xindex>
%4 = vector.step : vector<1xindex>
%5 = vector.broadcast %4 : vector<1xindex> to vector<4x1x1xindex>
%6 = vector.transpose %5, [2, 1, 0] : vector<4x1x1xindex> to vector<1x1x4xindex>
%cst_3 = arith.constant dense<true> : vector<1x1x4xi1>
%cst_4 = arith.constant dense<0> : vector<1x1x4xi32>
%c0_5 = arith.constant 0 : index
%c1_6 = arith.constant 1 : index
%dim = tensor.dim %cst, %c1_6 : tensor<30x1xi32>
%7 = vector.broadcast %dim : index to vector<1x1x4xindex>
%8 = arith.muli %6, %7 : vector<1x1x4xindex>
%cst_7 = arith.constant dense<0> : vector<1x1x4xindex>
%9 = arith.addi %cst_7, %8 : vector<1x1x4xindex>
%10 = vector.gather %cst[%c0_5, %c0_5] [%9], %cst_3, %cst_4 : tensor<30x1xi32>, vector<1x1x4xindex>, vector<1x1x4xi1>, vector<1x1x4xi32> into vector<1x1x4xi32>
%c0_8 = arith.constant 0 : index
%11 = vector.transfer_write %10, %arg0[%c0_8, %c0_8, %c0_8] : vector<1x1x4xi32>, tensor<1x1x4xi32>
return %11 : tensor<1x1x4xi32>
}
```
**AFTER THIS CHANGE**
```mlir
func.func @vectorize_scalar_broadcast_column_tensor(%arg0: tensor<1x1x4xi32>) -> tensor<1x1x4xi32> {
%c4 = arith.constant 4 : index
%c0 = arith.constant 0 : index
%cst = arith.constant dense<[[0], [0], [1], [1], [2], [2], [3], [3], [4], [4], [5], [5], [6], [6], [7], [7], [8], [8], [9], [9], [10], [10], [11], [11], [12], [12], [13], [13], [14], [14]]> : tensor<30x1xi32>
%c1 = arith.constant 1 : index
%c1_0 = arith.constant 1 : index
%c4_1 = arith.constant 4 : index
%c0_2 = arith.constant 0 : index
%c0_i32 = arith.constant 0 : i32
%0 = vector.transfer_read %arg0[%c0_2, %c0_2, %c0_2], %c0_i32 : tensor<1x1x4xi32>, vector<1x1x4xi32>
%1 = vector.step : vector<1xindex>
%2 = vector.broadcast %1 : vector<1xindex> to vector<4x1x1xindex>
%3 = vector.transpose %2, [2, 1, 0] : vector<4x1x1xindex> to vector<1x1x4xindex>
%4 = vector.step : vector<1xindex>
%5 = vector.broadcast %4 : vector<1xindex> to vector<4x1x1xindex>
%6 = vector.transpose %5, [2, 1, 0] : vector<4x1x1xindex> to vector<1x1x4xindex>
%cst_3 = arith.constant dense<true> : vector<1x1x4xi1>
%cst_4 = arith.constant dense<0> : vector<1x1x4xi32>
%c0_5 = arith.constant 0 : index
%c0_i32_6 = arith.constant 0 : i32
%7 = vector.shape_cast %6 : vector<1x1x4xindex> to vector<4xindex>
%8 = vector.extractelement %7[%c0_i32_6 : i32] : vector<4xindex>
%c0_i32_7 = arith.constant 0 : i32
%9 = vector.transfer_read %cst[%8, %c0], %c0_i32_7 {in_bounds = [true, true, true], permutation_map = #map} : tensor<30x1xi32>, vector<1x1x4xi32>
%c0_8 = arith.constant 0 : index
%10 = vector.transfer_write %9, %arg0[%c0_8, %c0_8, %c0_8] : vector<1x1x4xi32>, tensor<1x1x4xi32>
return %10 : tensor<1x1x4xi32>
}
```
So, we'd need to match:
```mlir
%4 = vector.step : vector<1xindex>
%5 = vector.broadcast %4 : vector<1xindex> to vector<4x1x1xindex>
%6 = vector.transpose %5, [2, 1, 0] : vector<4x1x1xindex> to vector<1x1x4xindex>
%dim = tensor.dim %cst, %c1_6 : tensor<30x1xi32>
%7 = vector.broadcast %dim : index to vector<1x1x4xindex>
%8 = arith.muli %6, %7 : vector<1x1x4xindex>
%cst_7 = arith.constant dense<0> : vector<1x1x4xindex>
%9 = arith.addi %cst_7, %8 : vector<1x1x4xindex>
%10 = vector.gather %cst[%c0_5, %c0_5] [%9], %cst_3, %cst_4 : tensor<30x1xi32>, vector<1x1x4xindex>, vector<1x1x4xi1>, vector<1x1x4xi32> into vector<1x1x4xi32>
```
Not that bad, but that's "roughly" what the vectoriser matches today to decide that the underlying `tensor.extract` is a broadcast of a scalar. IIUC, you are suggesting removing that logic from the vectorizer and creating a Vector dialect pattern instead? Would you go as far as simplifying the vectorizer to always generate `vector.gather` and then let Vector patterns "lower" that to something more efficient?
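If we went down that route, the matching would presumably move into a Vector dialect rewrite, roughly along these lines (a rough skeleton only, with the traversal elided and the pattern name made up; not code from this PR):
```cpp
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/PatternMatch.h"

using namespace mlir;

// Sketch: recognise a vector.gather whose index vector collapses to a single
// scalar index (step -> broadcast -> transpose -> muli -> addi chain) and
// rewrite it as a contiguous load. The actual matching is elided.
struct GatherOfBroadcastIndices : OpRewritePattern<vector::GatherOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::GatherOp gatherOp,
                                PatternRewriter &rewriter) const override {
    // 1. Walk the index computation and check that it is uniform, i.e. that
    //    every lane reads the same element of the source tensor.
    // 2. If so, replace the gather with vector.transfer_read (+ broadcast).
    return failure(); // Placeholder: no rewrite performed in this sketch.
  }
};
```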
https://github.com/llvm/llvm-project/pull/99299