[all-commits] [llvm/llvm-project] e22508: [mlir][vector] Update `CombineContractBroadcastMas...
Andrzej Warzyński via All-commits
all-commits at lists.llvm.org
Tue May 27 05:34:38 PDT 2025
Branch: refs/heads/main
Home: https://github.com/llvm/llvm-project
Commit: e22508ea8111a13d652f7a0e68a556794bfae519
https://github.com/llvm/llvm-project/commit/e22508ea8111a13d652f7a0e68a556794bfae519
Author: Andrzej Warzyński <andrzej.warzynski at arm.com>
Date: 2025-05-27 (Tue, 27 May 2025)
Changed paths:
M mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
M mlir/test/Dialect/Vector/vector-reduce-to-contract.mlir
Log Message:
-----------
[mlir][vector] Update `CombineContractBroadcastMask` (#140050)
This patch updates `CombineContractBroadcastMask` to inherit from
`MaskableOpRewritePattern`, enabling it to handle masked
`vector.contract` operations. The pattern rewrites:
```mlir
%a_bc = vector.broadcast %a
%res = vector.contract %a_bc, %b, ...
```
into:
```mlir
// Move the broadcast into vector.contract (by updating the indexing
// maps)
%res = vector.contract %a, %b, ...
```
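As a rough illustration of "moving the broadcast into vector.contract", the Python sketch below models an indexing map as a tuple of iteration-space dim positions, so `(0, 1, 3)` stands for `(d0, d1, d2, d3) -> (d0, d1, d3)`. The helper `fold_leading_broadcast` is hypothetical and not the C++ implementation; it only handles broadcasts that prepend dims and omits the legality checks the real pattern performs.

```python
from typing import Optional, Sequence, Tuple


def fold_leading_broadcast(
    operand_map: Sequence[int],
    src_shape: Sequence[int],
    bcast_shape: Sequence[int],
) -> Optional[Tuple[int, ...]]:
    """Return the operand's indexing map after folding
    `vector.broadcast src -> bcast` into the contraction (hypothetical model).

    Returns None when the broadcast is not a pure "prepend dims" broadcast.
    """
    rank_diff = len(bcast_shape) - len(src_shape)
    # Only rank-increasing broadcasts whose source matches the trailing dims
    # of the result are modeled here (no inner-dim stretching).
    if rank_diff <= 0 or tuple(bcast_shape[rank_diff:]) != tuple(src_shape):
        return None
    # The source only covers the trailing dims of the broadcast result, so
    # drop the map results that index the prepended dims.
    return tuple(operand_map[rank_diff:])


# Assumed LHS map (d0, d1, d2, d3) -> (d0, d1, d3) and a broadcast from
# vector<8x4xi32> to vector<1x8x4xi32>:
print(fold_leading_broadcast((0, 1, 3), (8, 4), (1, 8, 4)))  # (1, 3)
```

The new map reads the un-broadcast `vector<8x4xi32>` directly, which is why the `vector.broadcast` op can be removed.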
The main challenge is supporting cases where the pattern drops a leading
unit dimension from the iteration space, and therefore from the mask. For
example:
```mlir
func.func @contract_broadcast_unit_dim_reduction_masked(
%arg0 : vector<8x4xi32>,
%arg1 : vector<8x4xi32>,
%arg2 : vector<8x8xi32>,
%mask: vector<1x8x8x4xi1>) -> vector<8x8xi32> {
%0 = vector.broadcast %arg0 : vector<8x4xi32> to vector<1x8x4xi32>
%1 = vector.broadcast %arg1 : vector<8x4xi32> to vector<1x8x4xi32>
%result = vector.mask %mask {
vector.contract {
indexing_maps = [#map0, #map1, #map2],
iterator_types = ["reduction", "parallel", "parallel", "reduction"],
kind = #vector.kind<add>
} %0, %1, %arg2 : vector<1x8x4xi32>, vector<1x8x4xi32> into vector<8x8xi32>
} : vector<1x8x8x4xi1> -> vector<8x8xi32>
return %result : vector<8x8xi32>
}
```
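Here, folding both broadcasts leaves the leading unit dimension (a
reduction dimension of size 1) unused by every operand, so it is dropped
and the iteration space shrinks from 1x8x8x4 to 8x8x4. The mask must
shrink with it, so it is cast to the new shape using a `vector.shape_cast`: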
```mlir
func.func @contract_broadcast_unit_dim_reduction_masked(
%arg0: vector<8x4xi32>,
%arg1: vector<8x4xi32>,
%arg2: vector<8x8xi32>,
%arg3: vector<1x8x8x4xi1>) -> vector<8x8xi32> {
%mask_sc = vector.shape_cast %arg3 : vector<1x8x8x4xi1> to vector<8x8x4xi1>
%res = vector.mask %mask_sc {
vector.contract {
indexing_maps = [#map, #map1, #map2],
iterator_types = ["parallel", "parallel", "reduction"],
kind = #vector.kind<add>
} %arg0, %arg1, %arg2 : vector<8x4xi32>, vector<8x4xi32> into vector<8x8xi32>
} : vector<8x8x4xi1> -> vector<8x8xi32>
return %res : vector<8x8xi32>
}
```
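The step that removes the leading unit dimension can be sketched in Python as "find the iteration dims no indexing map uses, then renumber the rest". The maps `#map0 = (d0, d1, d3)`, `#map1 = (d0, d2, d3)` and `#map2 = (d1, d2)` are an assumption (the message does not define them), chosen to be consistent with the operand shapes and iterator types above. `compress_unused_dims` is a hypothetical helper, not the upstream code.

```python
from typing import List, Sequence, Tuple


def compress_unused_dims(
    maps: Sequence[Sequence[int]], iterator_types: Sequence[str]
) -> Tuple[List[Tuple[int, ...]], List[str], List[int]]:
    """Drop iteration dims that no map references (hypothetical model).

    Returns (compressed maps, compressed iterator types, dropped dim positions).
    """
    num_dims = len(iterator_types)
    used = {d for m in maps for d in m}
    kept = [d for d in range(num_dims) if d in used]
    dropped = [d for d in range(num_dims) if d not in used]
    # Old dim position -> new, compressed position.
    new_pos = {old: new for new, old in enumerate(kept)}
    new_maps = [tuple(new_pos[d] for d in m) for m in maps]
    new_iters = [iterator_types[d] for d in kept]
    return new_maps, new_iters, dropped


# Maps after folding both broadcasts (d0 is no longer referenced):
maps = [(1, 3), (2, 3), (1, 2)]
iters = ["reduction", "parallel", "parallel", "reduction"]
print(compress_unused_dims(maps, iters))
# ([(0, 2), (1, 2), (0, 1)], ['parallel', 'parallel', 'reduction'], [0])
```

The compressed iterator types match the rewritten `vector.contract`, and the dropped position `[0]` is exactly the mask dimension that the `vector.shape_cast` removes.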
This isn't ideal, since it introduces a `vector.shape_cast` that must be
cleaned up later, but it reflects the best we can do once the input
reaches `CombineContractBroadcastMask`. A more robust solution may
involve simplifying the input earlier; I am leaving that as a TODO to
explore further. I am posting this now to unblock downstream work.
LIMITATIONS
Currently, this pattern assumes:
* Only leading dimensions are dropped in the mask.
* All dropped dimensions must be unit-sized.
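The two limitations above can be expressed as a small check on the mask shape. In this Python sketch, the hypothetical helper `mask_shape_after_drop` returns the target shape for the `vector.shape_cast`, or `None` when either limitation is violated (modeling "the pattern does not apply"). Dropping only unit dims keeps the element count and order unchanged, which is what makes the shape cast valid.

```python
from typing import Optional, Sequence, Tuple


def mask_shape_after_drop(
    mask_shape: Sequence[int], dropped_dims: Sequence[int]
) -> Optional[Tuple[int, ...]]:
    """Target shape for casting the mask, or None if unsupported (hypothetical)."""
    dropped = sorted(set(dropped_dims))
    num_dropped = len(dropped)
    # Limitation 1: only leading dims may be dropped, i.e. exactly d0..d(k-1).
    if num_dropped > len(mask_shape) or dropped != list(range(num_dropped)):
        return None
    # Limitation 2: every dropped dim must be unit-sized, so the shape_cast
    # only removes 1s and the element count is unchanged.
    if any(mask_shape[d] != 1 for d in dropped):
        return None
    return tuple(mask_shape[num_dropped:])


print(mask_shape_after_drop((1, 8, 8, 4), [0]))  # (8, 8, 4)
print(mask_shape_after_drop((8, 1, 8, 4), [1]))  # None: not a leading dim
```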